-
Notifications
You must be signed in to change notification settings - Fork 670
/
root.go
191 lines (169 loc) · 6.71 KB
/
root.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package cmd
import (
"context"
"flag"
"fmt"
"os"
"runtime"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/config/viper"
"github.com/flyteorg/flyte/flytestdlib/contextutils"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/promutils/labeled"
"github.com/flyteorg/flyte/flytestdlib/storage"
"github.com/flyteorg/flyte/flytestdlib/version"
)
// RootOptions holds the state shared by the flytedata root command and its
// sub-commands: Kubernetes client configuration, the metrics scope, the
// storage client and the configuration accessor.
type RootOptions struct {
	// Kubernetes client config overrides; bound to the standard override flags
	// on the root command's persistent flag set.
	*clientcmd.ConfigOverrides

	// Set by the --show-source / -s flag ("Show line number for errors").
	showSource bool

	// Deferred-loading Kubernetes client config built from the kubeconfig
	// loading rules and ConfigOverrides.
	clientConfig clientcmd.ClientConfig

	// Metrics scope ("flyte:data"), created in the root PersistentPreRunE.
	Scope promutils.Scope

	// Storage client, created in the root PersistentPreRunE from the storage
	// config; used e.g. by UploadError.
	Store *storage.DataStore

	// Viper-backed config accessor; re-created by initConfig before sub-commands run.
	configAccessor config.Accessor

	// Path supplied via the --config flag; used as the config search path.
	cfgFile string

	// The actual key name that should be created under the remote prefix where
	// the error document is written (set via --err-output-name, default "errors.pb").
	errorOutputName string
}
// executeRootCmd is the action for a bare `flytedata` invocation with no
// sub-command. It logs the Go runtime version, target OS/arch and build
// information, then always returns an error telling the user to pick a
// sub-command.
func (r *RootOptions) executeRootCmd() error {
	ctx := context.TODO()

	goVersion, goOS, goArch := runtime.Version(), runtime.GOOS, runtime.GOARCH
	logger.Infof(ctx, "Go Version: %s", goVersion)
	logger.Infof(ctx, "Go OS/Arch: %s/%s", goOS, goArch)
	version.LogBuildInformation("flytedata")

	// The root command has no behavior of its own.
	return fmt.Errorf("use one of the sub-commands")
}
// UploadError records recvErr as a core.ErrorDocument in the data store.
//
// The document is written to the key r.errorOutputName under prefix. It holds a
// ContainerError with the given code, recvErr's message ("unknown error" when
// recvErr is nil) and kind RECOVERABLE. It returns the error from building the
// target reference (which is also logged) or from writing the protobuf.
func (r RootOptions) UploadError(ctx context.Context, code string, recvErr error, prefix storage.DataReference) error {
	target, refErr := r.Store.ConstructReference(ctx, prefix, r.errorOutputName)
	if refErr != nil {
		logger.Errorf(ctx, "failed to create error file path err: %s", refErr)
		return refErr
	}
	logger.Infof(ctx, "Uploading Error file to path [%s], errFile: %s", target, r.errorOutputName)

	message := "unknown error"
	if recvErr != nil {
		message = recvErr.Error()
	}

	doc := &core.ErrorDocument{
		Error: &core.ContainerError{
			Code:    code,
			Message: message,
			Kind:    core.ContainerError_RECOVERABLE,
		},
	}
	return r.Store.WriteProtobuf(ctx, target, storage.Options{}, doc)
}
// PollUntilTimeout polls condition every pollInterval using wait.PollUntil.
// Polling stops when condition reports done or returns an error, when timeout
// elapses, or when ctx is cancelled. It returns whatever wait.PollUntil returns.
func PollUntilTimeout(ctx context.Context, pollInterval, timeout time.Duration, condition wait.ConditionFunc) error {
	deadlineCtx, stop := context.WithTimeout(ctx, timeout)
	defer stop()

	// Closing of the deadline context's Done channel is what stops the poller.
	done := deadlineCtx.Done()
	return wait.PollUntil(pollInterval, condition, done)
}
// checkAWSCreds resolves AWS credentials through the SDK's default credential
// chain and checks that the mandatory fields are populated.
//
// Only the access key ID and the secret access key are required. A session
// token is present only for temporary credentials (STS, assumed roles, web
// identity). Long-lived static keys legitimately have an empty SessionToken, so
// requiring it would reject valid credentials and make waitForAWSCreds poll
// until its timeout.
//
// It returns the resolved credentials, or an error if the session cannot be
// created, no provider in the chain yields credentials, or a required field is
// empty.
func checkAWSCreds() (*credentials.Value, error) {
	sess, err := session.NewSession(&aws.Config{})
	if err != nil {
		return nil, err
	}

	// Determine the AWS credentials from the default credential chain.
	creds, err := sess.Config.Credentials.Get()
	if err != nil {
		return nil, err
	}

	if creds.AccessKeyID == "" || creds.SecretAccessKey == "" {
		return nil, fmt.Errorf("invalid data in credential fetch from provider %q: empty access key ID or secret access key", creds.ProviderName)
	}
	return &creds, nil
}
// waitForAWSCreds polls checkAWSCreds every 5 seconds until it succeeds or
// timeout elapses (or ctx is cancelled).
//
// Each failed attempt is logged and retried. On success the credential
// provider name is logged. If the credentials never become usable, the polling
// error is returned wrapped with the timeout, so the caller can tell what was
// being waited for instead of seeing a bare "timed out waiting for the
// condition".
func waitForAWSCreds(ctx context.Context, timeout time.Duration) error {
	err := PollUntilTimeout(ctx, time.Second*5, timeout, func() (bool, error) {
		creds, credErr := checkAWSCreds()
		if credErr != nil {
			// Not fatal: credentials may become available on a later attempt,
			// so report and keep polling.
			logger.Errorf(ctx, "failed to get AWS credentials: %s", credErr)
			return false, nil
		}
		if creds != nil {
			logger.Infof(ctx, "found AWS credentials from provider: %s", creds.ProviderName)
		}
		return true, nil
	})
	// errors.Wrapf (github.com/pkg/errors) returns nil when err is nil.
	return errors.Wrapf(err, "waiting for AWS credentials (timeout %s)", timeout)
}
// NewDataCommand returns a new instance of the co-pilot root command
// ("flytedata").
//
// The command has the download, upload and config sub-commands registered, and
// its persistent flags cover the kubeconfig overrides, --config, --show-source,
// --err-output-name and the config accessor's flags.
//
// Before any sub-command runs, PersistentPreRunE loads configuration and
// creates the "flyte:data" metrics scope. When storage is S3, it then waits up
// to 10 minutes for AWS credentials. Finally it builds the storage DataStore
// used by the sub-commands.
func NewDataCommand() *cobra.Command {
	// Upper bound on how long to wait for AWS credentials when storage is S3.
	const awsCredsWaitTimeout = 10 * time.Minute

	rootOpts := &RootOptions{}
	command := &cobra.Command{
		Use:   "flytedata",
		Short: "flytedata is a simple go binary that can be used to retrieve and upload data from/to remote stow store to local disk.",
		Long:  `flytedata when used with conjunction with flytepropeller eliminates the need to have any flyte library installed inside the container`,
		PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
			if err := rootOpts.initConfig(cmd, args); err != nil {
				return err
			}
			rootOpts.Scope = promutils.NewScope("flyte:data")
			cfg := storage.GetConfig()
			if cfg.Type == storage.TypeS3 {
				// Credentials may not be available yet at container start; block
				// until they resolve so the data store is created with them.
				if err := waitForAWSCreds(context.Background(), awsCredsWaitTimeout); err != nil {
					return err
				}
			}
			store, err := storage.NewDataStore(cfg, rootOpts.Scope)
			if err != nil {
				return errors.Wrap(err, "failed to create datastore client")
			}
			rootOpts.Store = store
			return nil
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			return rootOpts.executeRootCmd()
		},
	}

	command.AddCommand(NewDownloadCommand(rootOpts))
	command.AddCommand(NewUploadCommand(rootOpts))

	// Kubernetes client configuration: --kubeconfig plus the standard override
	// flags. Loading is deferred, so the flag values are read after parsing.
	loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
	loadingRules.DefaultClientConfig = &clientcmd.DefaultClientConfig
	rootOpts.ConfigOverrides = &clientcmd.ConfigOverrides{}
	kflags := clientcmd.RecommendedConfigOverrideFlags("")
	command.PersistentFlags().StringVar(&loadingRules.ExplicitPath, "kubeconfig", "", "Path to a kube config. Only required if out-of-cluster")
	clientcmd.BindOverrideFlags(rootOpts.ConfigOverrides, command.PersistentFlags(), kflags)
	rootOpts.clientConfig = clientcmd.NewInteractiveDeferredLoadingClientConfig(loadingRules, rootOpts.ConfigOverrides, os.Stdin)

	command.PersistentFlags().StringVar(&rootOpts.cfgFile, "config", "", "config file (default is $HOME/config.yaml)")
	command.PersistentFlags().BoolVarP(&rootOpts.showSource, "show-source", "s", false, "Show line number for errors")
	command.PersistentFlags().StringVar(&rootOpts.errorOutputName, "err-output-name", "errors.pb", "Actual key name under the prefix where the error protobuf should be written to")

	// Register the config accessor's flags on the root command's persistent flag
	// set, so they are global to every sub-command. initConfig later binds a
	// fresh accessor to this same flag set.
	rootOpts.configAccessor = viper.NewAccessor(config.Options{StrictMode: true})
	rootOpts.configAccessor.InitializePflags(command.PersistentFlags())

	command.AddCommand(viper.GetConfigCommand())
	return command
}
// initConfig loads the flytedata configuration before any sub-command runs.
//
// It replaces r.configAccessor with a strict viper accessor whose search path is
// r.cfgFile, binds it to the root command's persistent flags, and applies the
// configuration. It returns any error reported by UpdateConfig.
func (r *RootOptions) initConfig(cmd *cobra.Command, _ []string) error {
	r.configAccessor = viper.NewAccessor(config.Options{
		StrictMode:  true,
		SearchPaths: []string{r.cfgFile},
	})

	// The persistent flags were registered on the root command, and cmd may be a
	// sub-command. The accessor must therefore be bound to the root's flag set,
	// not to cmd's own.
	r.configAccessor.InitializePflags(cmd.Root().PersistentFlags())

	return r.configAccessor.UpdateConfig(context.TODO())
}
// init sets up flags and metric labels at package load:
//   - registers klog's flags on the Go standard flag set and exposes that set
//     through pflag's global CommandLine;
//   - parses the Go flag set with an empty argument list, exiting the process
//     if that fails;
//   - configures the label keys used by labeled metrics.
func init() {
	goFlags := flag.CommandLine
	klog.InitFlags(goFlags)
	pflag.CommandLine.AddGoFlagSet(goFlags)

	// NOTE(review): parsing no arguments presumably just marks the Go flag set as
	// parsed for libraries that check flag.Parsed() — confirm.
	if err := goFlags.Parse([]string{}); err != nil {
		logger.Errorf(context.TODO(), "Error in initializing: %v", err)
		os.Exit(-1)
	}

	labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey)
}