diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index 54d43b98e9e..f0798c45cfd 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -1251,11 +1251,6 @@ func (a *apiServer) PatchExperiment( } } - // `patch` represents the allowed mutations that can be performed on an experiment, in JSON - if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil { - return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID) - } - if newResources != nil { e, ok := experiment.ExperimentRegistry.Load(int(exp.Id)) if !ok { @@ -1264,6 +1259,15 @@ func (a *apiServer) PatchExperiment( if newResources.MaxSlots != nil { msg := sproto.SetGroupMaxSlots{MaxSlots: ptrs.Ptr(int(*newResources.MaxSlots))} + w, err := getWorkspaceByConfig(activeConfig) + if err != nil { + return nil, status.Errorf(codes.Internal, err.Error()) + } + + err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, err.Error()) + } e.SetGroupMaxSlots(msg) } if newResources.Weight != nil { @@ -1280,6 +1284,11 @@ func (a *apiServer) PatchExperiment( } } + // `patch` represents the allowed mutations that can be performed on an experiment, in JSON + if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil { + return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID) + } + if newCheckpointStorage != nil { checkpoints, err := experiment.ExperimentCheckpointsToGCRaw( ctx, diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index 11cf96ff24b..21c82deb979 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -150,6 +150,9 @@ func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, f if err == nil && len(globalBytes) == 0 { confBytes = wkspBytes } + if len(globalBytes) > 0 || len(wkspBytes) > 0 { + err = nil + } return err }) if err == sql.ErrNoRows || len(confBytes) == 0 { diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 7f08edd51ae..cbe37890e03 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -31,6 +31,8 @@ const ( // SlotsReqTooHighErr is the error reported when the requested slots violates the max slots // constraint. SlotsReqTooHighErr = "requested slots is violates max slots constraint" + // SlotsAlreadySetErr is the error reported when slots are already set in an invariant config. + SlotsAlreadySetErr = "max slots is already set in an invariant config policy" ) // ConfigPolicyWarning logs a warning for the configuration policy component. @@ -307,26 +309,26 @@ func configPolicyOverlap(config1, config2 interface{}) { // enforced max slots for the workspace if that's set as an invariant config, and returns the // requested max slots otherwise. Returns an error when max slots is not set as an invariant config // and the requested max slots violates the constriant. -func CanSetMaxSlots(slotsReq *int, wkspID int) (*int, error) { +func CanSetMaxSlots(slotsReq *int, wkspID int) error { if slotsReq == nil { - return slotsReq, nil + return nil } enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID, "invariant_config", "'resources' -> 'max_slots'", model.ExperimentType) if err != nil { - return nil, err + return err } - if enforcedMaxSlots != nil { - return enforcedMaxSlots, nil + if enforcedMaxSlots != nil && *slotsReq != *enforcedMaxSlots { + return fmt.Errorf(SlotsAlreadySetErr) } maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID, "constraints", "'resources' -> 'max_slots'", model.ExperimentType) if err != nil { - return nil, err + return err } var canSetReqSlots bool @@ -334,8 +336,8 @@ func CanSetMaxSlots(slotsReq *int, wkspID int) (*int, error) { canSetReqSlots = true } if !canSetReqSlots { - return nil, fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) + return fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) } - return slotsReq, nil + return nil } diff --git a/master/internal/configpolicy/utils_test.go b/master/internal/configpolicy/utils_test.go index dd5f80820d2..c3714d504f9 100644 --- a/master/internal/configpolicy/utils_test.go +++ b/master/internal/configpolicy/utils_test.go @@ -624,9 +624,8 @@ func TestCanSetMaxSlots(t *testing.T) { ctx := context.Background() w := createWorkspaceWithUser(ctx, t, user.ID) t.Run("nil slots request", func(t *testing.T) { - slots, err := CanSetMaxSlots(nil, w.ID) + err := CanSetMaxSlots(nil, w.ID) require.NoError(t, err) - require.Nil(t, slots) }) err := SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{ @@ -639,29 +638,18 @@ func TestCanSetMaxSlots(t *testing.T) { "max_slots": 13 } } -`), - Constraints: ptrs.Ptr(` -{ - "resources": { - "max_slots": 13 - } -} `), }) require.NoError(t, err) t.Run("slots different than config higher", func(t *testing.T) { - slots, err := CanSetMaxSlots(ptrs.Ptr(15), w.ID) - require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 13, *slots) + err = CanSetMaxSlots(ptrs.Ptr(15), w.ID) + require.ErrorContains(t, err, SlotsAlreadySetErr) }) t.Run("slots different than config lower", func(t *testing.T) { - slots, err := CanSetMaxSlots(ptrs.Ptr(10), w.ID) - require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 13, *slots) + err = CanSetMaxSlots(ptrs.Ptr(10), w.ID) + require.ErrorContains(t, err, SlotsAlreadySetErr) }) t.Run("just constraints slots higher", func(t *testing.T) { @@ -679,9 +667,8 @@ func TestCanSetMaxSlots(t *testing.T) { }) require.NoError(t, err) - slots, err := CanSetMaxSlots(ptrs.Ptr(25), w.ID) + err = CanSetMaxSlots(ptrs.Ptr(25), w.ID) require.ErrorContains(t, err, SlotsReqTooHighErr) - require.Nil(t, slots) }) t.Run("just constraints slots lower", func(t *testing.T) { @@ -699,9 +686,7 @@ func TestCanSetMaxSlots(t *testing.T) { }) require.NoError(t, err) - slots, err := CanSetMaxSlots(ptrs.Ptr(20), w.ID) + err = CanSetMaxSlots(ptrs.Ptr(20), w.ID) require.NoError(t, err) - require.NotNil(t, slots) - require.Equal(t, 20, *slots) }) } diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 01f1c955176..56ec98f0eba 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -418,13 +418,12 @@ func (e *internalExperiment) SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) { return } - slots, err := configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) + err = configpolicy.CanSetMaxSlots(msg.MaxSlots, w.ID) if err != nil { log.Warnf("unable to set max slots: %s", err.Error()) return } - msg.MaxSlots = slots resources := e.activeConfig.Resources() resources.SetMaxSlots(msg.MaxSlots) e.activeConfig.SetResources(resources)