Skip to content

Commit

Permalink
WIP controller <-> scheduler coordination on server status updates
Browse files Browse the repository at this point in the history
  • Loading branch information
lc525 committed Sep 10, 2024
1 parent 2774ab5 commit 29f3945
Show file tree
Hide file tree
Showing 17 changed files with 1,057 additions and 670 deletions.
1,327 changes: 712 additions & 615 deletions apis/go/mlops/scheduler/scheduler.pb.go

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion apis/mlops/scheduler/scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ message ServerStatusRequest {
/* ServerStatusResponse provides details of current server status
*/
message ServerStatusResponse {
/* Type of SterverStatus update. At the moment the scheduler doesn't combine multiple types of
* updates in the same response. However, the Type enum is forward-compatible with this
* possibility, by setting members to power-of-two values. This means enum values can be used
* as flags and combined with bitwise OR, with the exception of StatusResponseTypeUnknown.
*/
enum Type {
StatusResponseTypeUnknown = 0;
StatusUpdate = 1;
NonAuthoritativeReplicaInfo = 2;
ScalingRequest = 4;
}
Type type = 7;
string serverName = 1;
repeated ServerReplicaResources resources = 2;
int32 expectedReplicas = 3;
Expand Down Expand Up @@ -186,7 +198,9 @@ message ModelStatusRequest {

message ServerNotifyRequest {
string name = 1;
int32 expectedReplicas = 2;
uint32 expectedReplicas = 2;
uint32 minReplicas = 5;
uint32 maxReplicas = 6;
bool shared = 3;
optional KubernetesMeta kubernetesMeta = 4;
}
Expand Down
33 changes: 6 additions & 27 deletions operator/apis/mlops/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,34 +190,13 @@ func (m Model) AsSchedulerModel() (*scheduler.Model, error) {
md.ModelSpec.Requirements = append(md.ModelSpec.Requirements, *m.Spec.ModelType)
}
// Set Replicas
if m.Spec.Replicas != nil {
md.DeploymentSpec.Replicas = uint32(*m.Spec.Replicas)
} else {
if m.Spec.MinReplicas != nil {
// set replicas to the min replicas if not set
md.DeploymentSpec.Replicas = uint32(*m.Spec.MinReplicas)
} else {
md.DeploymentSpec.Replicas = 1
}
}

if m.Spec.MinReplicas != nil {
md.DeploymentSpec.MinReplicas = uint32(*m.Spec.MinReplicas)
if md.DeploymentSpec.Replicas < md.DeploymentSpec.MinReplicas {
return nil, fmt.Errorf("Number of replicas %d should be >= min replicas %d", md.DeploymentSpec.Replicas, md.DeploymentSpec.MinReplicas)
}
} else {
md.DeploymentSpec.MinReplicas = 0
}

if m.Spec.MaxReplicas != nil {
md.DeploymentSpec.MaxReplicas = uint32(*m.Spec.MaxReplicas)
if md.DeploymentSpec.Replicas > md.DeploymentSpec.MaxReplicas {
return nil, fmt.Errorf("Number of replicas %d should be <= max replicas %d", md.DeploymentSpec.Replicas, md.DeploymentSpec.MaxReplicas)
}
} else {
md.DeploymentSpec.MaxReplicas = 0
scalingSpec, err := GetValidatedScalingSpec(m.Spec.Replicas, m.Spec.MinReplicas, m.Spec.MaxReplicas)
if err != nil {
return nil, err
}
md.DeploymentSpec.Replicas = scalingSpec.Replicas
md.DeploymentSpec.MinReplicas = scalingSpec.MinReplicas
md.DeploymentSpec.MaxReplicas = scalingSpec.MaxReplicas

// Set memory bytes
if m.Spec.Memory != nil {
Expand Down
54 changes: 54 additions & 0 deletions operator/apis/mlops/v1alpha1/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
Copyright (c) 2024 Seldon Technologies Ltd.
Use of this software is governed by
(1) the license included in the LICENSE file or
(2) if the license included in the LICENSE file is the Business Source License 1.1,
the Change License after the Change Date as each is defined in accordance with the LICENSE file.
*/

package v1alpha1

import "fmt"

type ValidatedScalingSpec struct {
Replicas uint32
MinReplicas uint32
MaxReplicas uint32
}


func GetValidatedScalingSpec(replicas *int32, minReplicas *int32, maxReplicas *int32) (*ValidatedScalingSpec, error) {
var validatedSpec ValidatedScalingSpec

if replicas != nil && *replicas > 0 {
validatedSpec.Replicas = uint32(*replicas)
} else {
if minReplicas != nil && *minReplicas > 0 {
// set replicas to the min replicas when replicas is not set explicitly
validatedSpec.Replicas = uint32(*minReplicas)
} else {
validatedSpec.Replicas = 1
}
}

if minReplicas != nil && *minReplicas > 0 {
validatedSpec.MinReplicas = uint32(*minReplicas)
if validatedSpec.Replicas < validatedSpec.MinReplicas {
return nil, fmt.Errorf("number of replicas %d must be >= min replicas %d", validatedSpec.Replicas, validatedSpec.MinReplicas)
}
} else {
validatedSpec.MinReplicas = 0
}

if maxReplicas != nil && *maxReplicas > 0 {
validatedSpec.MaxReplicas = uint32(*maxReplicas)
if validatedSpec.Replicas > validatedSpec.MaxReplicas {
return nil, fmt.Errorf("number of replicas %d must be <= min replicas %d", validatedSpec.Replicas, validatedSpec.MaxReplicas)
}
} else {
validatedSpec.MaxReplicas = 0
}

return &validatedSpec, nil
}
80 changes: 70 additions & 10 deletions operator/scheduler/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,31 @@ func (s *SchedulerClient) ServerNotify(ctx context.Context, server *v1alpha1.Ser
}
grpcClient := scheduler.NewSchedulerClient(conn)

var replicas int32
var scalingSpec *v1alpha1.ValidatedScalingSpec
if !server.ObjectMeta.DeletionTimestamp.IsZero() {
replicas = 0
} else if server.Spec.Replicas != nil {
replicas = *server.Spec.Replicas
scalingSpec = &v1alpha1.ValidatedScalingSpec{
Replicas: 0,
MinReplicas: 0,
MaxReplicas: 0,
}
} else {
replicas = 1
scalingSpec, err = v1alpha1.GetValidatedScalingSpec(server.Spec.Replicas, server.Spec.MinReplicas, server.Spec.MaxReplicas)
if err != nil {
return err
}
}

request := &scheduler.ServerNotifyRequest{
Name: server.GetName(),
ExpectedReplicas: replicas,
ExpectedReplicas: scalingSpec.Replicas,
MinReplicas: scalingSpec.MinReplicas,
MaxReplicas: scalingSpec.MaxReplicas,
KubernetesMeta: &scheduler.KubernetesMeta{
Namespace: server.GetNamespace(),
Generation: server.GetGeneration(),
},
}
logger.Info("Notify server", "name", server.GetName(), "namespace", server.GetNamespace(), "replicas", replicas)
logger.Info("Notify server", "name", server.GetName(), "namespace", server.GetNamespace(), "replicas", scalingSpec.Replicas)
_, err = grpcClient.ServerNotify(
ctx,
request,
Expand All @@ -75,6 +82,10 @@ func (s *SchedulerClient) SubscribeServerEvents(ctx context.Context, grpcClient
if err != nil {
return err
}

// on reconnects send all k8s Server data
go handleExistingServers(ctx, namespace, s, grpcClient)

for {
event, err := stream.Recv()
if err != nil {
Expand Down Expand Up @@ -108,9 +119,29 @@ func (s *SchedulerClient) SubscribeServerEvents(ctx context.Context, grpcClient
logger.Info("Ignoring event for old generation", "currentGeneration", server.Generation, "eventGeneration", event.GetKubernetesMeta().Generation, "server", event.ServerName)
return nil
}
// Handle status update
server.Status.LoadedModelReplicas = event.NumLoadedModelReplicas
return s.updateServerStatus(server)

// The types of updates we may get from the scheduler are:
// 1. Status updates
// 2. Requests for changing the number of server replicas
// 3. Updates containing non-authoritative replica info, because the scheduler is in a
// discovery phase (just starting up, after a restart)
//
// At the moment, the scheduler doesn't send multiple types of updates in a single event;
switch event.GetType() {
case scheduler.ServerStatusResponse_StatusUpdate:
return s.applyStatusUpdates(ctx, server, event)
case scheduler.ServerStatusResponse_ScalingRequest:
if event.ExpectedReplicas != event.AvailableReplicas {
return s.applyReplicaUpdates(ctx, server, event)
} else {
return nil
}
case scheduler.ServerStatusResponse_NonAuthoritativeReplicaInfo:
// skip updating replica info, only update status
return s.updateServerStatus(server)
default: // we ignore unknown event types
return nil
}
})
if retryErr != nil {
logger.Error(err, "Failed to update status", "model", event.ServerName)
Expand All @@ -128,3 +159,32 @@ func (s *SchedulerClient) updateServerStatus(server *v1alpha1.Server) error {
}
return nil
}

// when need to notify the scheduler about existing Server configuration
func handleExistingServers(
ctx context.Context, namespace string, s *SchedulerClient, grpcClient scheduler.SchedulerClient) {
serverList := &v1alpha1.ServerList{}
// Get all servers in the namespace
err := s.List(
ctx,
serverList,
client.InNamespace(namespace),
)
if err != nil {
return
}

for _, server := range serverList.Items {
// servers that are not in the process of being deleted has DeletionTimestamp as zero
if server.ObjectMeta.DeletionTimestamp.IsZero() {
s.logger.V(1).Info("Calling NotifyServer (on reconnect)", "server", server.Name)
if err := s.ServerNotify(ctx, &server); err != nil {
s.logger.Error(err, "Failed to notify scheduler about initial Server parameters", "server", server.Name)
} else {
s.logger.V(1).Info("Load model called successfully", "server", server.Name)
}
} else {
s.logger.V(1).Info("Server being deleted, not notifying", "server", server.Name)
}
}
}
8 changes: 7 additions & 1 deletion scheduler/pkg/coordinator/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

const (
topicModelEvents = "model.event"
topicServerEvents = "server.event"
topicExperimentEvents = "experiment.event"
topicPipelineEvents = "pipeline.event"
)
Expand All @@ -39,6 +40,7 @@ type EventHub struct {
bus *busV3.Bus
logger log.FieldLogger
modelEventHandlerChannels []chan ModelEventMsg
serverEventHandlerChannels []chan ServerEventMsg
experimentEventHandlerChannels []chan ExperimentEventMsg
pipelineEventHandlerChannels []chan PipelineEventMsg
lock sync.RWMutex
Expand All @@ -59,7 +61,7 @@ func NewEventHub(l log.FieldLogger) (*EventHub, error) {
bus: bus,
}

hub.bus.RegisterTopics(topicModelEvents, topicExperimentEvents, topicPipelineEvents)
hub.bus.RegisterTopics(topicModelEvents, topicServerEvents, topicExperimentEvents, topicPipelineEvents)

return &hub, nil
}
Expand All @@ -74,6 +76,10 @@ func (h *EventHub) Close() {
close(c)
}

for _, c := range h.serverEventHandlerChannels {
close(c)
}

for _, c := range h.experimentEventHandlerChannels {
close(c)
}
Expand Down
4 changes: 2 additions & 2 deletions scheduler/pkg/coordinator/hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func TestNewEventHub(t *testing.T) {

tests := []test{
{
name: "Should register two topics",
expectedTopics: []string{topicModelEvents, topicExperimentEvents, topicPipelineEvents},
name: "Should register four topics",
expectedTopics: []string{topicModelEvents, topicServerEvents, topicExperimentEvents, topicPipelineEvents},
},
}

Expand Down
100 changes: 100 additions & 0 deletions scheduler/pkg/coordinator/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
Copyright (c) 2024 Seldon Technologies Ltd.
Use of this software is governed by
(1) the license included in the LICENSE file or
(2) if the license included in the LICENSE file is the Business Source License 1.1,
the Change License after the Change Date as each is defined in accordance with the LICENSE file.
*/

package coordinator

import (
"context"
"reflect"

busV3 "github.com/mustafaturan/bus/v3"
log "github.com/sirupsen/logrus"
)

func (h *EventHub) RegisterServerEventHandler(
name string,
queueSize int,
logger log.FieldLogger,
handle func(event ServerEventMsg),
) {
events := make(chan ServerEventMsg, queueSize)
h.addServerEventHandlerChannel(events)

go func() {
for e := range events {
handle(e)
}
}()

handler := h.newServerEventHandler(logger, events, handle)
h.bus.RegisterHandler(name, handler)
}

func (h *EventHub) newServerEventHandler(
logger log.FieldLogger,
events chan ServerEventMsg,
_ func(event ServerEventMsg),
) busV3.Handler {
handleServerEventMessage := func(_ context.Context, e busV3.Event) {
l := logger.WithField("func", "handleServerEventMessage")
l.Debugf("Received event on %s from %s (ID: %s, TxID: %s)", e.Topic, e.Source, e.ID, e.TxID)

me, ok := e.Data.(ServerEventMsg)
if !ok {
l.Warnf(
"Event (ID %s, TxID %s) on topic %s from %s is not a ServerEventMsg: %s",
e.ID,
e.TxID,
e.Topic,
e.Source,
reflect.TypeOf(e.Data).String(),
)
return
}

h.lock.RLock()
if h.closed {
return
}
// Propagate the busV3.Event source to the ServerEventMsg
// This is useful for logging, but also in case we want to distinguish
// the action to take based on where the event came from.
me.Source = e.Source
events <- me
h.lock.RUnlock()
}

return busV3.Handler{
Matcher: topicServerEvents,
Handle: handleServerEventMessage,
}
}

func (h *EventHub) addServerEventHandlerChannel(c chan ServerEventMsg) {
h.lock.Lock()
defer h.lock.Unlock()

h.serverEventHandlerChannels = append(h.serverEventHandlerChannels, c)
}

func (h *EventHub) PublishServerEvent(source string, event ServerEventMsg) {
err := h.bus.EmitWithOpts(
context.Background(),
topicServerEvents,
event,
busV3.WithSource(source),
)
if err != nil {
h.logger.WithError(err).Errorf(
"unable to publish server event message from %s to %s",
source,
topicServerEvents,
)
}
}
Loading

0 comments on commit 29f3945

Please sign in to comment.