Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrent map writes in executor #2947

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions executor/predictor/predictor_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,26 @@ var (
)

type PredictorProcess struct {
Ctx context.Context
Client client.SeldonApiClient
Log logr.Logger
ServerUrl *url.URL
Namespace string
Meta *payload.MetaData
Routing map[string]int32
Ctx context.Context
Client client.SeldonApiClient
Log logr.Logger
ServerUrl *url.URL
Namespace string
Meta *payload.MetaData
Routing map[string]int32
RoutingMutex *sync.RWMutex
}

func NewPredictorProcess(context context.Context, client client.SeldonApiClient, log logr.Logger, serverUrl *url.URL, namespace string, meta map[string][]string) PredictorProcess {
return PredictorProcess{
Ctx: context,
Client: client,
Log: log,
ServerUrl: serverUrl,
Namespace: namespace,
Meta: payload.NewFromMap(meta),
Routing: make(map[string]int32),
Ctx: context,
Client: client,
Log: log,
ServerUrl: serverUrl,
Namespace: namespace,
Meta: payload.NewFromMap(meta),
Routing: make(map[string]int32),
RoutingMutex: &sync.RWMutex{},
}
}

Expand Down Expand Up @@ -89,14 +91,18 @@ func (p *PredictorProcess) transformInput(node *v1.PredictiveUnit, msg payload.S
if err != nil {
return nil, err
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Predict(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else if callTransformInput {
msg, err := p.Client.Chain(p.Ctx, node.Name, msg)
if err != nil {
return nil, err
}
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.TransformInput(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else {
return msg, nil
Expand Down Expand Up @@ -189,7 +195,9 @@ func (p *PredictorProcess) aggregate(node *v1.PredictiveUnit, msg []payload.Seld
}

if callClient {
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
return p.Client.Combine(p.Ctx, node.Name, node.Endpoint.ServiceHost, p.getPort(node), msg, p.Meta.Meta)
} else {
return msg[0], nil
Expand All @@ -216,20 +224,26 @@ func (p *PredictorProcess) predictChildren(node *v1.PredictiveUnit, msg payload.
}(i, nodeChild, msg)
}
wg.Wait()
p.RoutingMutex.Lock()
p.Routing[node.Name] = -1
p.RoutingMutex.Unlock()
for i, err := range errs {
if err != nil {
return cmsgs[i], err
}
}
} else if route == -2 {
//Abort and return request
p.RoutingMutex.Lock()
p.Routing[node.Name] = -2
p.RoutingMutex.Unlock()
return msg, nil
} else {
cmsgs = make([]payload.SeldonPayload, 1)
cmsgs[0], err = p.Predict(&node.Children[route], msg)
p.RoutingMutex.Lock()
p.Routing[node.Name] = int32(route)
p.RoutingMutex.Unlock()
if err != nil {
return cmsgs[0], err
}
Expand Down