diff --git a/cmds/init.go b/cmds/init.go index f0a6dbb..0fe62d9 100644 --- a/cmds/init.go +++ b/cmds/init.go @@ -9,13 +9,13 @@ import ( "text/template" "github.com/spf13/cobra" - "github.com/yangchnet/pm/remote" "github.com/yangchnet/pm/utils" ) func InitCmd() *cobra.Command { var ( - onlyLocal bool + storeType string + remoteType string ) var initCmd = &cobra.Command{ @@ -36,23 +36,26 @@ func InitCmd() *cobra.Command { utils.CreateDirIfNotExist(pmDir) // 创建.pm主目录 utils.CreateDirIfNotExist(storeDir) // 创建存储目录 - var remoteType string = "git" - if len(args) > 0 { - remoteType = args[0] + remote, err := NewRemote(cmd.Context(), remoteType, nil) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) } - if onlyLocal { - remoteType = "empty" + // remote 初始化 + remoteConfigStr, err := remote.Init(cmd.Context()) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) } - remote, err := remote.NewRemote(cmd.Context(), remoteType, nil) + localStore, err := NewStore(cmd.Context(), storeType, nil) if err != nil { fmt.Println("Error:", err) os.Exit(1) } - // remote 初始化 - remoteConfigStr, err := remote.Init(cmd.Context()) + storeConfigStr, err := localStore.Init(cmd.Context()) if err != nil { fmt.Println("Error:", err) os.Exit(1) @@ -61,6 +64,7 @@ func InitCmd() *cobra.Command { // 写入配置文件 params := map[string]string{ "remoteConfig": remoteConfigStr, + "storeConfig": storeConfigStr, "localPath": filepath.Join(home, ".pm/store"), "userKeyPath": userKeyPath, } @@ -90,7 +94,9 @@ func InitCmd() *cobra.Command { }, } - initCmd.Flags().BoolVarP(&onlyLocal, "only-local", "", false, "only use local storage") + initCmd.Flags().StringVarP(&storeType, "store", "s", "file", "store type: [file, sqlite], default file") + + initCmd.Flags().StringVarP(&remoteType, "remote", "r", "empty", "remote type: [git, empty], default empty") return initCmd } @@ -98,6 +104,8 @@ func InitCmd() *cobra.Command { var confTmpl = ` {{.remoteConfig}} +{{.storeConfig}} + local: path: {{.localPath}} diff --git a/cmds/new.go b/cmds/new.go index 8785676..1e010d0 100644 --- a/cmds/new.go +++ b/cmds/new.go @@ -61,6 +61,7 @@ func GenerateCmd() *cobra.Command { } clipboard.WriteAll(password) + fmt.Println("密码已经复制到剪贴板") }, PostRun: func(cmd *cobra.Command, args []string) { service, err := NewService(cmd.Context()) diff --git a/cmds/service.go b/cmds/service.go index 0bccbed..592a5f5 100644 --- a/cmds/service.go +++ b/cmds/service.go @@ -6,7 +6,11 @@ import ( "github.com/yangchnet/pm/config" "github.com/yangchnet/pm/remote" + "github.com/yangchnet/pm/remote/empty" + gitremote "github.com/yangchnet/pm/remote/git" "github.com/yangchnet/pm/store" + filestore "github.com/yangchnet/pm/store/file-store" + sqlitestore "github.com/yangchnet/pm/store/sqlite-store" ) type service struct { @@ -20,12 +24,51 @@ func NewService(ctx context.Context) (*service, error) { return nil, fmt.Errorf("remote not found") } - remote, err := remote.NewRemote(ctx, remoteMap["type"].(string), remoteMap) + remote, err := NewRemote(ctx, remoteMap["type"].(string), remoteMap) if err != nil { return nil, err } + + storeMap := config.GetStringMap("store") + if len(storeMap) <= 0 { + return nil, fmt.Errorf("store not found") + } + + store, err := NewStore(ctx, storeMap["type"].(string), storeMap) + if err != nil { + return nil, err + } + return &service{ - store: store.NewSqliteStore(ctx), + store: store, remote: remote, }, nil } + +func NewStore(ctx context.Context, storeType string, storeConfig map[string]any) (store.Store, error) { + var localStore store.Store + switch storeType { + case "sqlite": + localStore = sqlitestore.NewSqliteStore(ctx) + case "file": + localStore = filestore.NewFileStore(ctx) + default: + return nil, fmt.Errorf("未知的store类型: %s", storeType) + } + + return localStore, nil +} + +func NewRemote(ctx context.Context, remoteType string, remoteMap map[string]any) (remote.Remote, error) { + var remote remote.Remote + switch remoteType { + case "git": + remote = gitremote.NewGitRemote(ctx, remoteMap) + case "empty": + remote = empty.NewEmptyRemote() + default: + return nil, fmt.Errorf("未知的remote类型: %s", remoteType) + } + + return remote, nil +} diff --git a/remote/remote.go b/remote/remote.go index f1d3143..9c402ad 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -2,10 +2,6 @@ package remote import ( "context" - "fmt" - - "github.com/yangchnet/pm/remote/empty" - gitremote "github.com/yangchnet/pm/remote/git" ) type Remote interface { @@ -18,17 +14,3 @@ type Remote interface { // Init 初始化remote,返回remote配置信息 Init(ctx context.Context) (string, error) } - -func NewRemote(ctx context.Context, remoteType string, remoteMap map[string]any) (Remote, error) { - var remote Remote - switch remoteType { - case "git": - remote = gitremote.NewGitRemote(ctx, remoteMap) - case "empty": - remote = empty.NewEmptyRemote() - default: - return nil, fmt.Errorf("未知的remote类型: %s", remoteType) - } - - return remote, nil -} diff --git a/store/err.go b/store/err.go new file mode 100644 index 0000000..83fa4eb --- /dev/null +++ b/store/err.go @@ -0,0 +1,8 @@ +package store + +import "errors" + +var ( + ErrAlreadyExists = errors.New("already exists") + ErrNotFound = errors.New("not found") +) diff --git a/store/file-store/file.go b/store/file-store/file.go new file mode 100644 index 0000000..1a9b17c --- /dev/null +++ b/store/file-store/file.go @@ -0,0 +1,138 @@ +package filestore + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/yangchnet/pm/config" + "github.com/yangchnet/pm/store" +) + +type FileStore struct { + localPath string +} + +func NewFileStore(ctx context.Context) *FileStore { + return &FileStore{ + localPath: config.GetString("local.path"), + } +} + +var _ store.Store = &FileStore{} + +func (s *FileStore) Init(ctx context.Context) (string, error) { + return `store: + type: file`, nil +} + +// Save 在使用cryptFunc对密码密文进行存储 +func (s *FileStore) Save(ctx context.Context, passwd *store.Passwd) error { + files, err := readAllPasswd(s.localPath) + if err != nil { + return err + } + + _, ok := files[passwd.Name+".passwd"] + if ok { + return store.ErrAlreadyExists + } + + f, err := os.Create(filepath.Join(s.localPath, passwd.Name+".passwd")) + if err != nil { + return err + } + defer f.Close() + + passwdByte, err := json.Marshal(passwd) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + encoded := base64.StdEncoding.EncodeToString(passwdByte) + _, err = f.WriteString(encoded) + if err != nil { + return err + } + + return nil +} + +// Get 获取密码 +func (s *FileStore) Get(ctx context.Context, name string) (*store.Passwd, error) { + files, err := readAllPasswd(s.localPath) + if err != nil { + return nil, err + } + + path, ok := files[name+".passwd"] + if !ok { + return nil, store.ErrNotFound + } + + f, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + decoded, err := base64.StdEncoding.DecodeString(string(f)) + if err != nil { + return nil, err + } + + var passwd store.Passwd + if err := json.Unmarshal(decoded, &passwd); err != nil { + return nil, err + } + + return &passwd, nil +} + +// SearchName 根据名称进行搜索并给出名称列表 +func (s *FileStore) SearchName(ctx context.Context, name string) ([]string, error) { + files, err := readAllPasswd(s.localPath) + if err != nil { + return nil, err + } + + var names []string + for k, _ := range files { + list := strings.Split(k, ".") + if len(list) < 1 { + continue + } + + if strings.Contains(strings.ToLower(list[0]), strings.ToLower(name)) { + names = append(names, list[0]) + } + } + + return names, nil +} + +// Delete 删除一个记录 +func (s *FileStore) Delete(ctx context.Context, name string) error { + return os.Remove(filepath.Join(s.localPath, name+".passwd")) +} + +func readAllPasswd(dir string) (map[string]string, error) { + files := make(map[string]string) + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && filepath.Ext(path) == ".passwd" { + files[info.Name()] = path + } + return nil + }) + if err != nil { + return nil, err + } + return files, nil +} diff --git a/store/interface.go b/store/interface.go index 4f6a2de..47af803 100644 --- a/store/interface.go +++ b/store/interface.go @@ -5,6 +5,9 @@ import ( ) type Store interface { + // Init 初始化存储 + Init(ctx context.Context) (string, error) + // Save 在使用cryptFunc对密码密文进行存储 Save(ctx context.Context, passwd *Passwd) error diff --git a/store/sqlite.go b/store/sqlite-store/sqlite.go similarity index 65% rename from store/sqlite.go rename to store/sqlite-store/sqlite.go index f4dafab..50ca015 100644 --- a/store/sqlite.go +++ b/store/sqlite-store/sqlite.go @@ -1,4 +1,4 @@ -package store +package sqlitestore import ( "context" @@ -6,6 +6,7 @@ import ( "time" "github.com/yangchnet/pm/config" + "github.com/yangchnet/pm/store" "gorm.io/driver/sqlite" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" @@ -24,24 +25,29 @@ func NewSqliteStore(ctx context.Context) *SqliteStore { } // 迁移 schema - db.AutoMigrate(&Passwd{}) + db.AutoMigrate(&store.Passwd{}) return &SqliteStore{ db: db, } } +func (s *SqliteStore) Init(ctx context.Context) (string, error) { + return `store: + type: sqlite`, nil +} + // Save 在使用cryptFunc对密码密文进行存储 -func (s *SqliteStore) Save(ctx context.Context, passwd *Passwd) error { +func (s *SqliteStore) Save(ctx context.Context, passwd *store.Passwd) error { passwd.CreateTime = time.Now() passwd.UpdateTime = time.Now() return s.db.Save(passwd).Error } // Get 获取密码密文 -func (s *SqliteStore) Get(ctx context.Context, name string) (*Passwd, error) { - var passwd *Passwd - if err := s.db.Model(&Passwd{}).Where("name = ?", name).First(&passwd).Error; err != nil { +func (s *SqliteStore) Get(ctx context.Context, name string) (*store.Passwd, error) { + var passwd *store.Passwd + if err := s.db.Model(&store.Passwd{}).Where("name = ?", name).First(&passwd).Error; err != nil { return nil, err } @@ -51,7 +57,7 @@ func (s *SqliteStore) Get(ctx context.Context, name string) (*Passwd, error) { // SearchName 根据名称进行搜索并给出名称列表 func (s *SqliteStore) SearchName(ctx context.Context, name string) ([]string, error) { var names []string - if err := s.db.Model(&Passwd{}).Where("name LIKE ?", "%"+name+"%").Select("name").Scan(&names).Error; err != nil { + if err := s.db.Model(&store.Passwd{}).Where("name LIKE ?", "%"+name+"%").Select("name").Scan(&names).Error; err != nil { return nil, err } return names, nil @@ -59,7 +65,7 @@ func (s *SqliteStore) SearchName(ctx context.Context, name string) ([]string, er // Delete 删除一个记录 func (s *SqliteStore) Delete(ctx context.Context, name string) error { - var passwd Passwd + var passwd store.Passwd if err := s.db.Where("name = ?", name).First(&passwd).Error; err != nil { return err }