Skip to content

Commit

Permalink
Refactor init config
Browse files Browse the repository at this point in the history
  • Loading branch information
moshe-kabala committed Aug 1, 2024
1 parent c741b21 commit dcc40b2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 28 deletions.
7 changes: 2 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ Complete documentation is available at https://docs.tensorleap.ai`,
}

func init() {
cobra.OnInitialize(initConfig)

RootCommand.Flags().BoolVar(&showVersionInfo, "version", false, "Show version information")
RootCommand.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.config/tensorleap/config.yaml)")
Expand All @@ -72,6 +71,8 @@ func init() {
if hubPkg.IsHubEnabled() {
RootCommand.AddCommand(hub.RootCommand)
}

initConfig()
}

func Execute() {
Expand All @@ -82,10 +83,6 @@ func Execute() {
}

func initConfig() {
cfgFileFromEnv := os.Getenv("TL_CLI_CONFIG_FILE")
if cfgFile == "" && cfgFileFromEnv != "" {
cfgFile = cfgFileFromEnv
}
err := config.InitConfig(cfgFile)
cobra.CheckErr(err)
}
82 changes: 59 additions & 23 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,78 @@ import (
"fmt"
"os"
"path"
"strings"

"github.com/spf13/viper"
"github.com/tensorleap/leap-cli/pkg/log"
)

const TL_CLI_CONFIG_FILE = "TL_CLI_CONFIG_FILE"

func getDefaultConfigPath() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("error getting home directory: %v", err)
}
return path.Join(homeDir, ".config", "tensorleap/config.yaml"), nil
}

func validateConfigPath(cfgFile string) error {
if cfgFile == "" {
return fmt.Errorf("config file is required")
}

if !path.IsAbs(cfgFile) {
return fmt.Errorf("cli config file must be an absolute path, current path: %s", cfgFile)
}
return nil
}

func InitConfig(cfgFile string) error {
if cfgFile != "" {
viper.SetConfigFile(cfgFile)
} else {
homeDir, err := os.UserHomeDir()
cfgFileFromEnv := os.Getenv(TL_CLI_CONFIG_FILE)

if cfgFile == "" && cfgFileFromEnv != "" {
cfgFile = cfgFileFromEnv
} else if cfgFile == "" {
var err error
cfgFile, err = getDefaultConfigPath()
if err != nil {
return err
}
}
err := validateConfigPath(cfgFile)
if err != nil {
return err
}

configDir := path.Join(homeDir, ".config", "tensorleap")
configName := "config"
configType := "yaml"
configPath := path.Join(configDir, fmt.Sprintf("%s.%s", configName, configType))
if err := os.MkdirAll(configDir, os.ModePerm); err != nil {
return err
}
if _, err := os.Stat(configPath); os.IsNotExist(err) {
file, err := os.Create(configPath)
if err != nil {
return fmt.Errorf("error creating config the file: %v", err)
}
file.Close()
}
viper.AddConfigPath(configDir)
viper.SetConfigType(configType)
viper.SetConfigName(configName)
cfgFile = strings.TrimSuffix(cfgFile, path.Ext(cfgFile))

configDir := path.Dir(cfgFile)
configName := path.Base(cfgFile)
configType := "yaml"
configPath := path.Join(configDir, fmt.Sprintf("%s.%s", configName, configType))
if err := os.MkdirAll(configDir, os.ModePerm); err != nil {
return err
}
if _, err := os.Stat(configPath); os.IsNotExist(err) {
file, err := os.Create(configPath)
if err != nil {
return fmt.Errorf("error creating config the file: %v", err)
}
file.Close()
}

viper.AutomaticEnv()
return viper.ReadInConfig()
viper.AddConfigPath(configDir)
viper.SetConfigType(configType)
viper.SetConfigName(configName)

viper.AutomaticEnv()
err = viper.ReadInConfig()
if err != nil {
return err
}
log.Infof("Using config file: %s", viper.ConfigFileUsed())
return nil
}

func Save() error {
Expand Down

0 comments on commit dcc40b2

Please sign in to comment.