diff --git a/examples/gpu-example.yaml b/examples/gpu-example.yaml index 6a8ff967274..9123338b1f7 100644 --- a/examples/gpu-example.yaml +++ b/examples/gpu-example.yaml @@ -33,7 +33,10 @@ spec: - ftrl workerSpec: goTemplate: - templatePath: "/worker-template/gpuWorkerTemplate.yaml" + templateSpec: + configMapName: "worker-template" + templatePath: "gpuWorkerTemplate.yaml" + configMapNamespace: "kubeflow" suggestionSpec: suggestionAlgorithm: "random" requestNumber: 3 diff --git a/examples/grid-example.yaml b/examples/grid-example.yaml index 10cce0273e9..52f55dde399 100644 --- a/examples/grid-example.yaml +++ b/examples/grid-example.yaml @@ -33,7 +33,10 @@ spec: - ftrl workerSpec: goTemplate: - templatePath: "/worker-template/cpuWorkerTemplate.yaml" + templateSpec: + configMapName: "worker-template" + templatePath: "cpuWorkerTemplate.yaml" + configMapNamespace: "kubeflow" suggestionSpec: suggestionAlgorithm: "grid" suggestionParameters: diff --git a/examples/hypb-example.yaml b/examples/hypb-example.yaml index a6fed35e836..1ee81b9c67f 100644 --- a/examples/hypb-example.yaml +++ b/examples/hypb-example.yaml @@ -37,7 +37,10 @@ spec: max: "20" workerSpec: goTemplate: - templatePath: "/worker-template/cpuWorkerTemplate.yaml" + templateSpec: + configMapName: "worker-template" + templatePath: "cpuWorkerTemplate.yaml" + configMapNamespace: "kubeflow" suggestionSpec: suggestionAlgorithm: "hyperband" suggestionParameters: diff --git a/pkg/api/operators/apis/studyjob/v1alpha1/studyjob_types.go b/pkg/api/operators/apis/studyjob/v1alpha1/studyjob_types.go index 1a606baa2a7..44e93115976 100644 --- a/pkg/api/operators/apis/studyjob/v1alpha1/studyjob_types.go +++ b/pkg/api/operators/apis/studyjob/v1alpha1/studyjob_types.go @@ -114,19 +114,25 @@ const ( OptimizationTypeMaximize OptimizationType = "maximize" ) -type GoTemplate struct { +type TemplateSpec struct { + ConfigMapName string `json:"configMapName,omitempty"` + ConfigMapNamespace string `json:"configMapNamespace,omitempty"` TemplatePath string `json:"templatePath,omitempty"` +} + +type GoTemplate struct { + TemplateSpec *TemplateSpec `json:"templateSpec,omitempty"` RawTemplate string `json:"rawTemplate,omitempty"` } type WorkerSpec struct { Retain bool `json:"retain,omitempty"` - GoTemplate GoTemplate `json:"goTemplate,omitempty"` + GoTemplate *GoTemplate `json:"goTemplate,omitempty"` } type MetricsCollectorSpec struct { Retain bool `json:"retain,omitempty"` - GoTemplate GoTemplate `json:"goTemplate,omitempty"` + GoTemplate *GoTemplate `json:"goTemplate,omitempty"` } type ServiceParameter struct { diff --git a/pkg/controller/studyjob/manifest_parser.go b/pkg/controller/studyjob/manifest_parser.go index 3689b74a50e..bce8f6deb4a 100644 --- a/pkg/controller/studyjob/manifest_parser.go +++ b/pkg/controller/studyjob/manifest_parser.go @@ -17,7 +17,6 @@ package studyjob import ( "bytes" "context" - "fmt" "text/template" katibapi "github.com/kubeflow/katib/pkg/api" @@ -28,30 +27,50 @@ import ( "k8s.io/apimachinery/pkg/util/uuid" ) -func getWorkerManifest(c katibapi.ManagerClient, studyID string, trial *katibapi.Trial, workerSpec *katibv1alpha1.WorkerSpec, kind string, ns string, dryrun bool) (string, *bytes.Buffer, error) { - var wtp *template.Template = nil - var err error - if workerSpec != nil && workerSpec.GoTemplate.RawTemplate != "" { - wtp, err = template.New("Worker").Parse(workerSpec.GoTemplate.RawTemplate) +func getTemplateStrFromConfigMap(cNamespace, cName, tPath string) (string, error) { + sjc, err := studyjobclient.NewStudyjobClient(nil) + if err != nil { + return "", err + } + return sjc.GetTemplate(cNamespace, cName, tPath) +} + +func getTemplateStr(goTemplate *katibv1alpha1.GoTemplate, getDefaultTemplateSpec func()(string,string,string)) (string, error) { + if goTemplate.RawTemplate != "" { + return goTemplate.RawTemplate, nil } else { - wPath := "defaultWorkerTemplate.yaml" - if workerSpec != nil && workerSpec.GoTemplate.TemplatePath != "" { - wPath = workerSpec.GoTemplate.TemplatePath - } - sjc, err := studyjobclient.NewStudyjobClient(nil) - if err != nil { - return "", nil, err - } - wtl, err := sjc.GetWorkerTemplates() - if err != nil { - return "", nil, err - } - if wt, ok := wtl[wPath]; !ok { - return "", nil, fmt.Errorf("No worker template name %s", wPath) - } else { - wtp, err = template.New("Worker").Parse(wt) + tName, tNamespace, tPath := getDefaultTemplateSpec() + if goTemplate.TemplateSpec != nil { + tName = goTemplate.TemplateSpec.ConfigMapName + tNamespace = goTemplate.TemplateSpec.ConfigMapNamespace + tPath = goTemplate.TemplateSpec.TemplatePath } + return getTemplateStrFromConfigMap(tNamespace, tName, tPath) + } +} + +func getDefaultWorkerTemplateSpec() (string, string, string) { + return getKatibNamespace(), "worker-template", "defaultWorkerTemplate.yaml" +} + +func getDefaultMetricsTemplateSpec() (string, string, string) { + return getKatibNamespace(), "metricscollector-template", "defaultMetricsCollectorTemplate.yaml" +} + +func getWorkerTemplateStr(workerSpec *katibv1alpha1.WorkerSpec) (string, error) { + if workerSpec == nil || workerSpec.GoTemplate == nil { + return getTemplateStrFromConfigMap(getDefaultWorkerTemplateSpec()) + } else { + return getTemplateStr(workerSpec.GoTemplate, getDefaultWorkerTemplateSpec) } +} + +func getWorkerManifest(c katibapi.ManagerClient, studyID string, trial *katibapi.Trial, workerSpec *katibv1alpha1.WorkerSpec, kind string, ns string, dryrun bool) (string, *bytes.Buffer, error) { + wStr, err := getWorkerTemplateStr(workerSpec) + if err != nil { + return "", nil, err + } + wtp, err := template.New("Worker").Parse(wStr) if err != nil { return "", nil, err } @@ -91,9 +110,15 @@ func getWorkerManifest(c katibapi.ManagerClient, studyID string, trial *katibapi return wid, &b, nil } +func getMetricsCollectorTemplateStr(mcs *katibv1alpha1.MetricsCollectorSpec) (string, error) { + if mcs == nil || mcs.GoTemplate == nil { + return getTemplateStrFromConfigMap(getDefaultMetricsTemplateSpec()) + } else { + return getTemplateStr(mcs.GoTemplate, getDefaultMetricsTemplateSpec) + } +} + func getMetricsCollectorManifest(studyID string, trialID string, workerID string, workerKind string, namespace string, mcs *katibv1alpha1.MetricsCollectorSpec) (*bytes.Buffer, error) { - var mtp *template.Template = nil - var err error tmpValues := map[string]string{ "StudyID": studyID, "TrialID": trialID, @@ -102,27 +127,11 @@ func getMetricsCollectorManifest(studyID string, trialID string, workerID string "NameSpace": namespace, "ManagerSerivce": pkg.GetManagerAddr(), } - if mcs != nil && mcs.GoTemplate.RawTemplate != "" { - mtp, err = template.New("MetricsCollector").Parse(mcs.GoTemplate.RawTemplate) - } else { - mctp := "defaultMetricsCollectorTemplate.yaml" - if mcs != nil && mcs.GoTemplate.TemplatePath != "" { - mctp = mcs.GoTemplate.TemplatePath - } - sjc, err := studyjobclient.NewStudyjobClient(nil) - if err != nil { - return nil, err - } - mtl, err := sjc.GetMetricsCollectorTemplates() - if err != nil { - return nil, err - } - if mt, ok := mtl[mctp]; !ok { - return nil, fmt.Errorf("No MetricsCollector template name %s", mctp) - } else { - mtp, err = template.New("MetricsCollector").Parse(mt) - } + tStr, err := getMetricsCollectorTemplateStr(mcs) + if err != nil { + return nil, err } + mtp, err := template.New("MetricsCollector").Parse(tStr) if err != nil { return nil, err } diff --git a/pkg/controller/studyjob/utils.go b/pkg/controller/studyjob/utils.go index b0fae6e6d9c..430744aea98 100644 --- a/pkg/controller/studyjob/utils.go +++ b/pkg/controller/studyjob/utils.go @@ -13,7 +13,9 @@ package studyjob import ( "fmt" + "io/ioutil" "log" + "strings" katibapi "github.com/kubeflow/katib/pkg/api" katibv1alpha1 "github.com/kubeflow/katib/pkg/api/operators/apis/studyjob/v1alpha1" @@ -209,3 +211,8 @@ func contains(l []string, s string) bool { } return false } + +func getKatibNamespace() string { + data, _ := ioutil.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace") + return strings.TrimSpace(string(data)) +} \ No newline at end of file diff --git a/pkg/manager/studyjobclient/studyjobclient.go b/pkg/manager/studyjobclient/studyjobclient.go index 953a583adca..2e79e4debd5 100644 --- a/pkg/manager/studyjobclient/studyjobclient.go +++ b/pkg/manager/studyjobclient/studyjobclient.go @@ -1,6 +1,7 @@ package studyjobclient import ( + "fmt" "io/ioutil" "strings" @@ -102,6 +103,18 @@ func (s *StudyjobClient) GetMetricsCollectorTemplates(namespace ...string) (map[ return cm.Data, nil } +func (s *StudyjobClient) GetTemplate(namespace, name, path string) (string, error) { + cm, err := s.clientset.CoreV1().ConfigMaps(namespace).Get(name, metav1.GetOptions{}) + if err != nil { + return "", err + } + if _, ok := cm.Data[path]; !ok { + return "", fmt.Errorf("No tamplate name %s in configMap %s/%s", path, namespace, name) + } else { + return cm.Data[path], nil + } +} + func (s *StudyjobClient) UpdateMetricsCollectorTemplates(newMCTemplates map[string]string, namespace ...string) error { ns := getNamespace(namespace...) cm, err := s.clientset.CoreV1().ConfigMaps(ns).Get("metricscollector-template", metav1.GetOptions{})