Skip to content

Commit

Permalink
chore: set experiment config later
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Oct 25, 2024
1 parent 782f7a0 commit d7f2de3
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
19 changes: 14 additions & 5 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1251,11 +1251,6 @@ func (a *apiServer) PatchExperiment(
}
}

// `patch` represents the allowed mutations that can be performed on an experiment, in JSON
if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil {
return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID)
}

if newResources != nil {
e, ok := experiment.ExperimentRegistry.Load(int(exp.Id))
if !ok {
Expand All @@ -1264,6 +1259,15 @@ func (a *apiServer) PatchExperiment(

if newResources.MaxSlots != nil {
msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))}
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
}

err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}
e.SetGroupMaxSlots(msg)
}
if newResources.Weight != nil {
Expand All @@ -1280,6 +1284,11 @@ func (a *apiServer) PatchExperiment(
}
}

// `patch` represents the allowed mutations that can be performed on an experiment, in JSON
if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil {
return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID)
}

if newCheckpointStorage != nil {
checkpoints, err := experiment.ExperimentCheckpointsToGCRaw(
ctx,
Expand Down
3 changes: 3 additions & 0 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, f
if err == nil && len(globalBytes) == 0 {
confBytes = wkspBytes
}
if len(globalBytes) > 0 || len(wkspBytes) > 0 {
err = nil
}
return err
})
if err == sql.ErrNoRows || len(confBytes) == 0 {
Expand Down
18 changes: 10 additions & 8 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
// SlotsReqTooHighErr is the error reported when the requested slots violates the max slots
// constraint.
SlotsReqTooHighErr = "requested slots is violates max slots constraint"
// SlotsAlreadySetErr is the error reported when slots are already set in an invariant config.
SlotsAlreadySetErr = "max slots is already set in an invariant config policy"
)

// ConfigPolicyWarning logs a warning for the configuration policy component.
Expand Down Expand Up @@ -307,35 +309,35 @@ func configPolicyOverlap(config1, config2 interface{}) {
// enforced max slots for the workspace if that's set as an invariant config, and returns the
// requested max slots otherwise. Returns an error when max slots is not set as an invariant config
// and the requested max slots violates the constriant.
func CanSetMaxSlots(slotsReq *int, wkspID int) (*int, error) {
func CanSetMaxSlots(slotsReq *int, wkspID int) error {
if slotsReq == nil {
return slotsReq, nil
return nil
}
enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return nil, err
return err
}

if enforcedMaxSlots != nil {
return enforcedMaxSlots, nil
if enforcedMaxSlots != nil && *slotsReq != *enforcedMaxSlots {
return fmt.Errorf(SlotsAlreadySetErr)
}

maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
if err != nil {
return nil, err
return err
}

var canSetReqSlots bool
if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit {
canSetReqSlots = true
}
if !canSetReqSlots {
return nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit)
return fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit)
}

return slotsReq, nil
return nil
}
29 changes: 7 additions & 22 deletions master/internal/configpolicy/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,8 @@ func TestCanSetMaxSlots(t *testing.T) {
ctx := context.Background()
w := createWorkspaceWithUser(ctx, t, user.ID)
t.Run("nil slots request", func(t *testing.T) {
slots, err := CanSetMaxSlots(nil, w.ID)
err := CanSetMaxSlots(nil, w.ID)
require.NoError(t, err)
require.Nil(t, slots)
})

err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
Expand All @@ -639,29 +638,18 @@ func TestCanSetMaxSlots(t *testing.T) {
"max_slots": 13
}
}
`),
Constraints: ptrs.Ptr(`
{
"resources": {
"max_slots": 13
}
}
`),
})
require.NoError(t, err)

t.Run("slots different than config higher", func(t *testing.T) {
slots, err := CanSetMaxSlots(ptrs.Ptr(15), w.ID)
require.NoError(t, err)
require.NotNil(t, slots)
require.Equal(t, 13, *slots)
err = CanSetMaxSlots(ptrs.Ptr(15), w.ID)
require.ErrorContains(t, err, SlotsAlreadySetErr)
})

t.Run("slots different than config lower", func(t *testing.T) {
slots, err := CanSetMaxSlots(ptrs.Ptr(10), w.ID)
require.NoError(t, err)
require.NotNil(t, slots)
require.Equal(t, 13, *slots)
err = CanSetMaxSlots(ptrs.Ptr(10), w.ID)
require.ErrorContains(t, err, SlotsAlreadySetErr)
})

t.Run("just constraints slots higher", func(t *testing.T) {
Expand All @@ -679,9 +667,8 @@ func TestCanSetMaxSlots(t *testing.T) {
})
require.NoError(t, err)

slots, err := CanSetMaxSlots(ptrs.Ptr(25), w.ID)
err = CanSetMaxSlots(ptrs.Ptr(25), w.ID)
require.ErrorContains(t, err, SlotsReqTooHighErr)
require.Nil(t, slots)
})

t.Run("just constraints slots lower", func(t *testing.T) {
Expand All @@ -699,9 +686,7 @@ func TestCanSetMaxSlots(t *testing.T) {
})
require.NoError(t, err)

slots, err := CanSetMaxSlots(ptrs.Ptr(20), w.ID)
err = CanSetMaxSlots(ptrs.Ptr(20), w.ID)
require.NoError(t, err)
require.NotNil(t, slots)
require.Equal(t, 20, *slots)
})
}
3 changes: 1 addition & 2 deletions master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,12 @@ func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) {
return
}

slots, err := configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID)
err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID)
if err != nil {
log.Warnf("unable to set max slots: %s", err.Error())
return
}

msg.MaxSlots = slots
resources := e.activeConfig.Resources()
resources.SetMaxSlots(msg.MaxSlots)
e.activeConfig.SetResources(resources)
Expand Down

0 comments on commit d7f2de3

Please sign in to comment.