Skip to content

Commit

Permalink
main.go: Fix style (#55)
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <[email protected]>
  • Loading branch information
gaocegege authored Apr 19, 2018
1 parent 56e143a commit eb61c8e
Showing 1 changed file with 33 additions and 61 deletions.
94 changes: 33 additions & 61 deletions manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,18 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"

batchv1 "k8s.io/api/batch/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)

const (
k8s_namespace = "katib"
namespace = "katib"
port = "0.0.0.0:6789"
defaultEarlyStopInterval = 60
defaultSaveInterval = 30
)

var init_db = flag.Bool("init", false, "Initialize DB")
var worker = flag.String("w", "kubernetes", "Worker Typw")
var workerType = flag.String("w", "kubernetes", "Worker Type")
var dbIf kdb.VizierDBInterface

type studyCh struct {
Expand All @@ -49,13 +45,13 @@ type server struct {
StudyChList map[string]studyCh
}

func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error {
func (s *server) saveCompletedModels(studyID string, conf *pb.StudyConfig) error {
ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name})
if err != nil {
log.Printf("GetSavedModels Err %v", err)
return err
}
ts, err := dbIf.GetTrialList(studyId)
ts, err := dbIf.GetTrialList(studyID)
if err != nil {
log.Printf("GetTrials Err %v", err)
return err
Expand Down Expand Up @@ -99,9 +95,9 @@ func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error
return nil
}

func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh studyCh) error {
defer delete(s.StudyChList, study_id)
defer s.wIF.CleanWorkers(study_id)
func (s *server) trialIteration(conf *pb.StudyConfig, studyID string, sCh studyCh) error {
defer delete(s.StudyChList, studyID)
defer s.wIF.CleanWorkers(studyID)
tm := time.NewTimer(1 * time.Second)
ei := 0
var err error
Expand All @@ -118,41 +114,41 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
}
estm := time.NewTimer(time.Duration(ei) * time.Second)
strtm := time.NewTimer(defaultSaveInterval * time.Second)
log.Printf("Study %v start.", study_id)
log.Printf("Study %v start.", studyID)
log.Printf("Study conf %v", conf)
for {
select {
case <-tm.C:
if conf.SuggestAlgorithm != "" {
err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName)
err := s.wIF.CheckRunningTrials(studyID, conf.ObjectiveValueName)
if err != nil {
return err
}
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: studyID, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
if err != nil {
log.Printf("SuggestTrials failed %v", err)
return err
}
if r.Completed {
log.Printf("Study %v completed.", study_id)
return s.saveCompletedModels(study_id, conf)
log.Printf("Study %v completed.", studyID)
return s.saveCompletedModels(studyID, conf)
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
trial.StudyId = study_id
trial.StudyId = studyID
err = dbIf.CreateTrial(trial)
if err != nil {
log.Printf("CreateTrial failed %v", err)
return err
}
}
err = s.wIF.SpawnWorkers(r.Trials, study_id)
err = s.wIF.SpawnWorkers(r.Trials, studyID)
if err != nil {
log.Printf("SpawnWorkers failed %v", err)
return err
}
for _, t := range r.Trials {
err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount)
err = tbif.SpawnTensorBoard(studyID, t.TrialId, namespace, conf.Mount)
if err != nil {
log.Printf("SpawnTB failed %v", err)
return err
Expand All @@ -162,27 +158,27 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
tm.Reset(1 * time.Second)
}
case <-strtm.C:
s.saveCompletedModels(study_id, conf)
s.saveCompletedModels(studyID, conf)
strtm.Reset(defaultSaveInterval * time.Second)

case <-estm.C:
ret, err := s.EarlyStopping(context.Background(), &pb.EarlyStoppingRequest{StudyId: study_id, EarlyStoppingAlgorithm: conf.EarlyStoppingAlgorithm})
ret, err := s.EarlyStopping(context.Background(), &pb.EarlyStoppingRequest{StudyId: studyID, EarlyStoppingAlgorithm: conf.EarlyStoppingAlgorithm})
if err != nil {
log.Printf("Early Stopping Error: %v", err)
} else {
if len(ret.Trials) > 0 {
for _, t := range ret.Trials {
s.CompleteTrial(context.Background(), &pb.CompleteTrialRequest{StudyId: study_id, TrialId: t.TrialId, IsComplete: false})
s.CompleteTrial(context.Background(), &pb.CompleteTrialRequest{StudyId: studyID, TrialId: t.TrialId, IsComplete: false})
}
}
}
estm.Reset(time.Duration(ei) * time.Second)
case <-sCh.stopCh:
log.Printf("Study %v is stopped.", study_id)
for _, t := range s.wIF.GetRunningTrials(study_id) {
log.Printf("Study %v is stopped.", studyID)
for _, t := range s.wIF.GetRunningTrials(studyID) {
t.Status = pb.TrialState_KILLED
}
return s.saveCompletedModels(study_id, conf)
return s.saveCompletedModels(studyID, conf)
case m := <-sCh.addMetricsCh:
conf.Metrics = append(conf.Metrics, m)
}
Expand All @@ -199,12 +195,12 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p
return &pb.CreateStudyReply{}, errors.New("Objective_Value_Name is required.")
}

study_id, err := dbIf.CreateStudy(in.StudyConfig)
studyID, err := dbIf.CreateStudy(in.StudyConfig)
if in.StudyConfig.SuggestAlgorithm != "" {
_, err = s.InitializeSuggestService(
ctx,
&pb.InitializeSuggestServiceRequest{
StudyId: study_id,
StudyId: studyID,
SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm,
SuggestionParameters: in.StudyConfig.SuggestionParameters,
Configs: in.StudyConfig,
Expand All @@ -227,7 +223,7 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p
_, err = c.SetEarlyStoppingParameter(
context.Background(),
&pb.SetEarlyStoppingParameterRequest{
StudyId: study_id,
StudyId: studyID,
EarlyStoppingParameters: in.StudyConfig.EarlyStoppingParameters,
},
)
Expand All @@ -240,9 +236,9 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p
if err != nil {
return &pb.CreateStudyReply{}, err
}
go s.trialIteration(in.StudyConfig, study_id, sCh)
s.StudyChList[study_id] = sCh
return &pb.CreateStudyReply{StudyId: study_id}, nil
go s.trialIteration(in.StudyConfig, studyID, sCh)
s.StudyChList[studyID] = sCh
return &pb.CreateStudyReply{StudyId: studyID}, nil
}

func (s *server) StopStudy(ctx context.Context, in *pb.StopStudyRequest) (*pb.StopStudyReply, error) {
Expand All @@ -254,30 +250,6 @@ func (s *server) StopStudy(ctx context.Context, in *pb.StopStudyRequest) (*pb.St
return &pb.StopStudyReply{}, nil
}

func spawn_worker(study_task string, params string) error {
config, err := rest.InClusterConfig()
if err != nil {
return err
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return err
}
template, err := clientset.CoreV1().PodTemplates(k8s_namespace).Get(study_task, metav1.GetOptions{})
if err != nil {
return err
}

_, err = clientset.BatchV1().Jobs(k8s_namespace).Create(&batchv1.Job{
Spec: batchv1.JobSpec{
Template: template.Template,
},
})

// TODO: Update worker status
return err
}

func (s *server) GetStudies(ctx context.Context, in *pb.GetStudiesRequest) (*pb.GetStudiesReply, error) {
ss := make([]*pb.StudyInfo, len(s.StudyChList))
i := 0
Expand Down Expand Up @@ -312,7 +284,7 @@ func (s *server) InitializeSuggestService(ctx context.Context, in *pb.Initialize
}

func (s *server) SuggestTrials(ctx context.Context, in *pb.SuggestTrialsRequest) (*pb.SuggestTrialsReply, error) {
var suggest_algo string
var suggestAlgorithm string

// TODO: only a few columns are needed but GetStudyConfig does a full retrieval
study, err := dbIf.GetStudyConfig(in.StudyId)
Expand All @@ -321,14 +293,14 @@ func (s *server) SuggestTrials(ctx context.Context, in *pb.SuggestTrialsRequest)
}

if in.SuggestAlgorithm != "" {
suggest_algo = in.SuggestAlgorithm
suggestAlgorithm = in.SuggestAlgorithm
} else if study.SuggestAlgorithm != "" {
suggest_algo = study.SuggestAlgorithm
suggestAlgorithm = study.SuggestAlgorithm
} else {
return &pb.SuggestTrialsReply{Completed: false}, errors.New("No suggest algorithm specified")
}

conn, err := grpc.Dial("vizier-suggestion-"+suggest_algo+":6789", grpc.WithInsecure())
conn, err := grpc.Dial("vizier-suggestion-"+suggestAlgorithm+":6789", grpc.WithInsecure())
if err != nil {
return &pb.SuggestTrialsReply{Completed: false}, err
}
Expand Down Expand Up @@ -418,7 +390,7 @@ func main() {
}
size := 1<<31 - 1
s := grpc.NewServer(grpc.MaxRecvMsgSize(size), grpc.MaxSendMsgSize(size))
switch *worker {
switch *workerType {
case "kubernetes":
log.Printf("Worker: kubernetes\n")
kc, err := clientcmd.BuildConfigFromFlags("", "/conf/kubeconfig")
Expand All @@ -432,7 +404,7 @@ func main() {
pb.RegisterManagerServer(s, &server{wIF: k8swif.NewKubernetesWorkerInterface(clientset, dbIf), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)})
case "dlk":
log.Printf("Worker: dlk\n")
pb.RegisterManagerServer(s, &server{wIF: dlkwif.NewDlkWorkerInterface("http://dlk-manager:1323", k8s_namespace), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)})
pb.RegisterManagerServer(s, &server{wIF: dlkwif.NewDlkWorkerInterface("http://dlk-manager:1323", namespace), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)})
case "nv-docker":
log.Printf("Worker: nv-docker\n")
pb.RegisterManagerServer(s, &server{wIF: nvdwif.NewNvDockerWorkerInterface(), msIf: modelstore.NewModelDB("modeldb-backend", "6543"), StudyChList: make(map[string]studyCh)})
Expand Down

0 comments on commit eb61c8e

Please sign in to comment.