diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index 6e7157a622f7c..d6e48b6d2d6ff 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -187,8 +188,14 @@ func Ping(cfg *Config) (*PingResponse, error) { if err != nil { return nil, trace.Wrap(err) } - defer resp.Body.Close() + if resp.StatusCode == http.StatusBadRequest { + per := &PingErrorResponse{} + if err := json.NewDecoder(resp.Body).Decode(per); err != nil { + return nil, trace.Wrap(err) + } + return nil, errors.New(per.Error.Message) + } pr := &PingResponse{} if err := json.NewDecoder(resp.Body).Decode(pr); err != nil { return nil, trace.Wrap(err) @@ -265,6 +272,17 @@ type PingResponse struct { MinClientVersion string `json:"min_client_version"` } +// PingErrorResponse contains the error message if the requested connector +// does not match one that has been registered. +type PingErrorResponse struct { + Error PingError `json:"error"` +} + +// PingError contains the string message from the PingErrorResponse +type PingError struct { + Message string `json:"message"` +} + // ProxySettings contains basic information about proxy settings type ProxySettings struct { // Kube is a kubernetes specific proxy section diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index b77a090596857..05ee6edc60f5f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -811,31 +811,53 @@ func (h *Handler) pingWithConnector(w http.ResponseWriter, r *http.Request, p ht return response, nil } + // collectorNames stores a list of the registered collector names so that + // in the event that no connector has matched, the list can be returned. + var collectorNames []string + // first look for a oidc connector with that name - oidcConnector, err := authClient.GetOIDCConnector(r.Context(), connectorName, false) + oidcConnectors, err := authClient.GetOIDCConnectors(r.Context(), false) if err == nil { - response.Auth = oidcSettings(oidcConnector, cap) - response.Auth.HasMessageOfTheDay = hasMessageOfTheDay - return response, nil + for index, value := range oidcConnectors { + collectorNames = append(collectorNames, value.GetMetadata().Name) + if value.GetMetadata().Name == connectorName { + response.Auth = oidcSettings(oidcConnectors[index], cap) + response.Auth.HasMessageOfTheDay = hasMessageOfTheDay + return response, nil + } + } } // if no oidc connector was found, look for a saml connector - samlConnector, err := authClient.GetSAMLConnector(r.Context(), connectorName, false) + samlConnectors, err := authClient.GetSAMLConnectors(r.Context(), false) if err == nil { - response.Auth = samlSettings(samlConnector, cap) - response.Auth.HasMessageOfTheDay = hasMessageOfTheDay - return response, nil + for index, value := range samlConnectors { + collectorNames = append(collectorNames, value.GetMetadata().Name) + if value.GetMetadata().Name == connectorName { + response.Auth = samlSettings(samlConnectors[index], cap) + response.Auth.HasMessageOfTheDay = hasMessageOfTheDay + return response, nil + } + } } // look for github connector - githubConnector, err := authClient.GetGithubConnector(r.Context(), connectorName, false) + githubConnectors, err := authClient.GetGithubConnectors(r.Context(), false) if err == nil { - response.Auth = githubSettings(githubConnector, cap) - response.Auth.HasMessageOfTheDay = hasMessageOfTheDay - return response, nil + for index, value := range githubConnectors { + collectorNames = append(collectorNames, value.GetMetadata().Name) + if value.GetMetadata().Name == connectorName { + response.Auth = githubSettings(githubConnectors[index], cap) + response.Auth.HasMessageOfTheDay = hasMessageOfTheDay + return response, nil + } + } } - return nil, trace.BadParameter("invalid connector name %v", connectorName) + return nil, + trace.BadParameter( + "invalid connector name: %v; valid options: %s", + connectorName, strings.Join(collectorNames, ", ")) } // getWebConfig returns configuration for the web application. diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 810c8b15eee36..a826c690216c1 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -399,7 +399,7 @@ func Run(args []string, opts ...cliOption) error { BoolVar(&cf.InsecureSkipVerify) } - app.Flag("auth", "Specify the type of authentication connector to use.").Envar(authEnvVar).StringVar(&cf.AuthConnector) + app.Flag("auth", "Specify the name of authentication connector to use.").Envar(authEnvVar).StringVar(&cf.AuthConnector) app.Flag("namespace", "Namespace of the cluster").Default(apidefaults.Namespace).Hidden().StringVar(&cf.Namespace) app.Flag("gops", "Start gops endpoint on a given address").Hidden().BoolVar(&cf.Gops) app.Flag("gops-addr", "Specify gops addr to listen on").Hidden().StringVar(&cf.GopsAddr)