diff --git a/manager/main.go b/manager/main.go index d602c99fb09..f7c7adea028 100644 --- a/manager/main.go +++ b/manager/main.go @@ -21,22 +21,18 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/reflection" - batchv1 "k8s.io/api/batch/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" ) const ( - k8s_namespace = "katib" + namespace = "katib" port = "0.0.0.0:6789" defaultEarlyStopInterval = 60 defaultSaveInterval = 30 ) -var init_db = flag.Bool("init", false, "Initialize DB") -var worker = flag.String("w", "kubernetes", "Worker Typw") +var workerType = flag.String("w", "kubernetes", "Worker Type") var dbIf kdb.VizierDBInterface type studyCh struct { @@ -49,13 +45,13 @@ type server struct { StudyChList map[string]studyCh } -func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error { +func (s *server) saveCompletedModels(studyID string, conf *pb.StudyConfig) error { ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name}) if err != nil { log.Printf("GetSavedModels Err %v", err) return err } - ts, err := dbIf.GetTrialList(studyId) + ts, err := dbIf.GetTrialList(studyID) if err != nil { log.Printf("GetTrials Err %v", err) return err @@ -99,9 +95,9 @@ func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error return nil } -func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh studyCh) error { - defer delete(s.StudyChList, study_id) - defer s.wIF.CleanWorkers(study_id) +func (s *server) trialIteration(conf *pb.StudyConfig, studyID string, sCh studyCh) error { + defer delete(s.StudyChList, studyID) + defer s.wIF.CleanWorkers(studyID) tm := time.NewTimer(1 * time.Second) ei := 0 var err error @@ -118,41 +114,41 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study } estm := time.NewTimer(time.Duration(ei) * time.Second) strtm := time.NewTimer(defaultSaveInterval * time.Second) - log.Printf("Study %v start.", study_id) + log.Printf("Study %v start.", studyID) log.Printf("Study conf %v", conf) for { select { case <-tm.C: if conf.SuggestAlgorithm != "" { - err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName) + err := s.wIF.CheckRunningTrials(studyID, conf.ObjectiveValueName) if err != nil { return err } - r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf}) + r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: studyID, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf}) if err != nil { log.Printf("SuggestTrials failed %v", err) return err } if r.Completed { - log.Printf("Study %v completed.", study_id) - return s.saveCompletedModels(study_id, conf) + log.Printf("Study %v completed.", studyID) + return s.saveCompletedModels(studyID, conf) } else if len(r.Trials) > 0 { for _, trial := range r.Trials { trial.Status = pb.TrialState_PENDING - trial.StudyId = study_id + trial.StudyId = studyID err = dbIf.CreateTrial(trial) if err != nil { log.Printf("CreateTrial failed %v", err) return err } } - err = s.wIF.SpawnWorkers(r.Trials, study_id) + err = s.wIF.SpawnWorkers(r.Trials, studyID) if err != nil { log.Printf("SpawnWorkers failed %v", err) return err } for _, t := range r.Trials { - err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount) + err = tbif.SpawnTensorBoard(studyID, t.TrialId, namespace, conf.Mount) if err != nil { log.Printf("SpawnTB failed %v", err) return err @@ -162,27 +158,27 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study tm.Reset(1 * time.Second) } case <-strtm.C: - s.saveCompletedModels(study_id, conf) + s.saveCompletedModels(studyID, conf) strtm.Reset(defaultSaveInterval * time.Second) case <-estm.C: - ret, err := s.EarlyStopping(context.Background(), &pb.EarlyStoppingRequest{StudyId: study_id, EarlyStoppingAlgorithm: conf.EarlyStoppingAlgorithm}) + ret, err := s.EarlyStopping(context.Background(), &pb.EarlyStoppingRequest{StudyId: studyID, EarlyStoppingAlgorithm: conf.EarlyStoppingAlgorithm}) if err != nil { log.Printf("Early Stopping Error: %v", err) } else { if len(ret.Trials) > 0 { for _, t := range ret.Trials { - s.CompleteTrial(context.Background(), &pb.CompleteTrialRequest{StudyId: study_id, TrialId: t.TrialId, IsComplete: false}) + s.CompleteTrial(context.Background(), &pb.CompleteTrialRequest{StudyId: studyID, TrialId: t.TrialId, IsComplete: false}) } } } estm.Reset(time.Duration(ei) * time.Second) case <-sCh.stopCh: - log.Printf("Study %v is stopped.", study_id) - for _, t := range s.wIF.GetRunningTrials(study_id) { + log.Printf("Study %v is stopped.", studyID) + for _, t := range s.wIF.GetRunningTrials(studyID) { t.Status = pb.TrialState_KILLED } - return s.saveCompletedModels(study_id, conf) + return s.saveCompletedModels(studyID, conf) case m := <-sCh.addMetricsCh: conf.Metrics = append(conf.Metrics, m) } @@ -199,12 +195,12 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p return &pb.CreateStudyReply{}, errors.New("Objective_Value_Name is required.") } - study_id, err := dbIf.CreateStudy(in.StudyConfig) + studyID, err := dbIf.CreateStudy(in.StudyConfig) if in.StudyConfig.SuggestAlgorithm != "" { _, err = s.InitializeSuggestService( ctx, &pb.InitializeSuggestServiceRequest{ - StudyId: study_id, + StudyId: studyID, SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm, SuggestionParameters: in.StudyConfig.SuggestionParameters, Configs: in.StudyConfig, @@ -227,7 +223,7 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p _, err = c.SetEarlyStoppingParameter( context.Background(), &pb.SetEarlyStoppingParameterRequest{ - StudyId: study_id, + StudyId: studyID, EarlyStoppingParameters: in.StudyConfig.EarlyStoppingParameters, }, ) @@ -240,9 +236,9 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p if err != nil { return &pb.CreateStudyReply{}, err } - go s.trialIteration(in.StudyConfig, study_id, sCh) - s.StudyChList[study_id] = sCh - return &pb.CreateStudyReply{StudyId: study_id}, nil + go s.trialIteration(in.StudyConfig, studyID, sCh) + s.StudyChList[studyID] = sCh + return &pb.CreateStudyReply{StudyId: studyID}, nil } func (s *server) StopStudy(ctx context.Context, in *pb.StopStudyRequest) (*pb.StopStudyReply, error) { @@ -254,30 +250,6 @@ func (s *server) StopStudy(ctx context.Context, in *pb.StopStudyRequest) (*pb.St return &pb.StopStudyReply{}, nil } -func spawn_worker(study_task string, params string) error { - config, err := rest.InClusterConfig() - if err != nil { - return err - } - clientset, err := kubernetes.NewForConfig(config) - if err != nil { - return err - } - template, err := clientset.CoreV1().PodTemplates(k8s_namespace).Get(study_task, metav1.GetOptions{}) - if err != nil { - return err - } - - _, err = clientset.BatchV1().Jobs(k8s_namespace).Create(&batchv1.Job{ - Spec: batchv1.JobSpec{ - Template: template.Template, - }, - }) - - // TODO: Update worker status - return err -} - func (s *server) GetStudies(ctx context.Context, in *pb.GetStudiesRequest) (*pb.GetStudiesReply, error) { ss := make([]*pb.StudyInfo, len(s.StudyChList)) i := 0 @@ -312,7 +284,7 @@ func (s *server) InitializeSuggestService(ctx context.Context, in *pb.Initialize } func (s *server) SuggestTrials(ctx context.Context, in *pb.SuggestTrialsRequest) (*pb.SuggestTrialsReply, error) { - var suggest_algo string + var suggestAlgorithm string // TODO: only a few columns are needed but GetStudyConfig does a full retrieval study, err := dbIf.GetStudyConfig(in.StudyId) @@ -321,14 +293,14 @@ func (s *server) SuggestTrials(ctx context.Context, in *pb.SuggestTrialsRequest) } if in.SuggestAlgorithm != "" { - suggest_algo = in.SuggestAlgorithm + suggestAlgorithm = in.SuggestAlgorithm } else if study.SuggestAlgorithm != "" { - suggest_algo = study.SuggestAlgorithm + suggestAlgorithm = study.SuggestAlgorithm } else { return &pb.SuggestTrialsReply{Completed: false}, errors.New("No suggest algorithm specified") } - conn, err := grpc.Dial("vizier-suggestion-"+suggest_algo+":6789", grpc.WithInsecure()) + conn, err := grpc.Dial("vizier-suggestion-"+suggestAlgorithm+":6789", grpc.WithInsecure()) if err != nil { return &pb.SuggestTrialsReply{Completed: false}, err } @@ -418,7 +390,7 @@ func main() { } size := 1<<31 - 1 s := grpc.NewServer(grpc.MaxRecvMsgSize(size), grpc.MaxSendMsgSize(size)) - switch *worker { + switch *workerType { case "kubernetes": log.Printf("Worker: kubernetes\n") kc, err := clientcmd.BuildConfigFromFlags("", "/conf/kubeconfig") @@ -432,7 +404,7 @@ func main() { pb.RegisterManagerServer(s, &server{wIF: k8swif.NewKubernetesWorkerInterface(clientset, dbIf), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)}) case "dlk": log.Printf("Worker: dlk\n") - pb.RegisterManagerServer(s, &server{wIF: dlkwif.NewDlkWorkerInterface("http://dlk-manager:1323", k8s_namespace), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)}) + pb.RegisterManagerServer(s, &server{wIF: dlkwif.NewDlkWorkerInterface("http://dlk-manager:1323", namespace), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)}) case "nv-docker": log.Printf("Worker: nv-docker\n") pb.RegisterManagerServer(s, &server{wIF: nvdwif.NewNvDockerWorkerInterface(), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)})