diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index c156cb81fd6c..09a67468efb0 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -23,6 +23,7 @@ go_library( "//pkg/ccl/sqlproxyccl/throttler", "//pkg/roachpb:with-mocks", "//pkg/security/certmgr", + "//pkg/sql/pgwire/pgcode", "//pkg/util/contextutil", "//pkg/util/grpcutil", "//pkg/util/httputil", @@ -68,6 +69,7 @@ go_test( "//pkg/server", "//pkg/sql", "//pkg/sql/pgwire", + "//pkg/sql/pgwire/pgerror", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index fe09dcf7fb57..00ac6de13dfa 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -12,6 +12,7 @@ import ( "io" "net" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" ) @@ -50,20 +51,22 @@ func toPgError(err error) *pgproto3.ErrorResponse { var pgCode string if codeErr.code == codeIdleDisconnect { - pgCode = "57P01" // admin shutdown + pgCode = pgcode.AdminShutdown.String() } else { - pgCode = "08004" // rejected connection + pgCode = pgcode.SQLserverRejectedEstablishmentOfSQLconnection.String() } + return &pgproto3.ErrorResponse{ Severity: "FATAL", Code: pgCode, Message: msg, + Hint: errors.FlattenHints(codeErr.err), } } // Return a generic "internal server error" message. return &pgproto3.ErrorResponse{ Severity: "FATAL", - Code: "08004", // rejected connection + Code: pgcode.SQLserverRejectedEstablishmentOfSQLconnection.String(), Message: "internal server error", } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index ad6a5e0b0940..cff6c9fd7f8e 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -42,7 +42,7 @@ var ( // Unlike the original spec, this does not handle escaping rules. // // See "options" in https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS. - clusterNameLongOptionRE = regexp.MustCompile(`(?:-c\s*|--)cluster=([\S]*)`) + clusterIdentifierLongOptionRE = regexp.MustCompile(`(?:-c\s*|--)cluster=([\S]*)`) // clusterNameRegex restricts cluster names to have between 6 and 20 // alphanumeric characters, with dashes allowed within the name (but not as a @@ -51,10 +51,9 @@ var ( ) const ( - // Cluster identifier is in the form "clustername-. Tenant id is - // always in the end but the cluster name can also contain '-' or digits. - // For example: - // "foo-7-10" -> cluster name is "foo-7" and tenant id is 10. + // Cluster identifier is in the form "-. Tenant ID + // is always in the end but the cluster name can also contain '-' or digits. + // (e.g. In "foo-7-10", cluster name is "foo-7" and tenant ID is "10") clusterTenantSep = "-" ) @@ -566,66 +565,95 @@ var reportFailureToDirectory = func( return directory.ReportFailure(ctx, tenantID, addr) } -// clusterNameAndTenantFromParams extracts the cluster name from the connection -// parameters, and rewrites the database param, if necessary. We currently -// support embedding the cluster name in two ways: -// - Within the database param (e.g. "happy-koala.defaultdb") +// clusterNameAndTenantFromParams extracts the cluster name and tenant ID from +// the connection parameters, and rewrites the database and options parameters, +// if necessary. // -// - Within the options param (e.g. "... --cluster=happy-koala ..."). +// We currently support embedding the cluster identifier in two ways: +// +// - Within the database param (e.g. "happy-koala-3.defaultdb") +// +// - Within the options param (e.g. "... --cluster=happy-koala-5 ..."). // PostgreSQL supports three different ways to set a run-time parameter // through its command-line options, i.e. "-c NAME=VALUE", "-cNAME=VALUE", and // "--NAME=VALUE". func clusterNameAndTenantFromParams( ctx context.Context, msg *pgproto3.StartupMessage, ) (*pgproto3.StartupMessage, string, roachpb.TenantID, error) { - clusterNameFromDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) + clusterIdentifierDB, databaseName, err := parseDatabaseParam(msg.Parameters["database"]) if err != nil { return msg, "", roachpb.MaxTenantID, err } - clusterNameFromOpt, newOptionsParam, err := parseOptionsParam(msg.Parameters["options"]) + clusterIdentifierOpt, newOptionsParam, err := parseOptionsParam(msg.Parameters["options"]) if err != nil { return msg, "", roachpb.MaxTenantID, err } - if clusterNameFromDB == "" && clusterNameFromOpt == "" { - return msg, "", roachpb.MaxTenantID, errors.New("missing cluster name in connection string") + // No cluster identifiers were specified. + if clusterIdentifierDB == "" && clusterIdentifierOpt == "" { + err := errors.New("missing cluster identifier") + err = errors.WithHint(err, clusterIdentifierHint) + return msg, "", roachpb.MaxTenantID, err } - if clusterNameFromDB != "" && clusterNameFromOpt != "" { - return msg, "", roachpb.MaxTenantID, errors.New("multiple cluster names provided") + // Ambiguous cluster identifiers. + if clusterIdentifierDB != "" && clusterIdentifierOpt != "" && + clusterIdentifierDB != clusterIdentifierOpt { + err := errors.New("multiple different cluster identifiers provided") + err = errors.WithHintf(err, + "Is '%s' or '%s' the identifier for the cluster that you're connecting to?", + clusterIdentifierDB, clusterIdentifierOpt) + err = errors.WithHint(err, clusterIdentifierHint) + return msg, "", roachpb.MaxTenantID, err } - if clusterNameFromDB == "" { - clusterNameFromDB = clusterNameFromOpt + if clusterIdentifierDB == "" { + clusterIdentifierDB = clusterIdentifierOpt } - sepIdx := strings.LastIndex(clusterNameFromDB, clusterTenantSep) + sepIdx := strings.LastIndex(clusterIdentifierDB, clusterTenantSep) - // Cluster name provided without a tenant ID in the end. - if sepIdx == -1 || sepIdx == len(clusterNameFromDB)-1 { - return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) + // Cluster identifier provided without a tenant ID in the end. + if sepIdx == -1 || sepIdx == len(clusterIdentifierDB)-1 { + err := errors.Errorf("invalid cluster identifier '%s'", clusterIdentifierDB) + err = errors.WithHint(err, missingTenantIDHint) + err = errors.WithHint(err, clusterNameFormHint) + return msg, "", roachpb.MaxTenantID, err } - clusterNameSansTenant, tenantIDStr := clusterNameFromDB[:sepIdx], clusterNameFromDB[sepIdx+1:] - if !clusterNameRegex.MatchString(clusterNameSansTenant) { - return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) + clusterName, tenantIDStr := clusterIdentifierDB[:sepIdx], clusterIdentifierDB[sepIdx+1:] + + // Cluster name does not conform to the expected format (e.g. too short). + if !clusterNameRegex.MatchString(clusterName) { + err := errors.Errorf("invalid cluster identifier '%s'", clusterIdentifierDB) + err = errors.WithHintf(err, "Is '%s' a valid cluster name?", clusterName) + err = errors.WithHint(err, clusterNameFormHint) + return msg, "", roachpb.MaxTenantID, err } + // Tenant ID cannot be parsed. tenID, err := strconv.ParseUint(tenantIDStr, 10, 64) if err != nil { // Log these non user-facing errors. - log.Errorf(ctx, "cannot parse tenant ID in %s: %v", clusterNameFromDB, err) - return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) + log.Errorf(ctx, "cannot parse tenant ID in %s: %v", clusterIdentifierDB, err) + err := errors.Errorf("invalid cluster identifier '%s'", clusterIdentifierDB) + err = errors.WithHintf(err, "Is '%s' a valid tenant ID?", tenantIDStr) + err = errors.WithHint(err, clusterNameFormHint) + return msg, "", roachpb.MaxTenantID, err } + // This case only happens if tenID is 0 or 1 (system tenant). if tenID < roachpb.MinTenantID.ToUint64() { // Log these non user-facing errors. - log.Errorf(ctx, "%s contains an invalid tenant ID", clusterNameFromDB) - return msg, "", roachpb.MaxTenantID, errors.Errorf("invalid cluster name '%s'", clusterNameFromDB) + log.Errorf(ctx, "%s contains an invalid tenant ID", clusterIdentifierDB) + err := errors.Errorf("invalid cluster identifier '%s'", clusterIdentifierDB) + err = errors.WithHintf(err, "Tenant ID %d is invalid.", tenID) + return msg, "", roachpb.MaxTenantID, err } // Make and return a copy of the startup msg so the original is not modified. + // We will rewrite database and options in the new startup message. paramsOut := map[string]string{} for key, value := range msg.Parameters { if key == "database" { @@ -638,20 +666,21 @@ func clusterNameAndTenantFromParams( paramsOut[key] = value } } + outMsg := &pgproto3.StartupMessage{ ProtocolVersion: msg.ProtocolVersion, Parameters: paramsOut, } - - return outMsg, clusterNameSansTenant, roachpb.MakeTenantID(tenID), nil + return outMsg, clusterName, roachpb.MakeTenantID(tenID), nil } // parseDatabaseParam parses the database parameter from the PG connection -// string, and tries to extract the cluster name if present. The cluster -// name should be embedded in the database parameter using the dot (".") -// delimiter in the form of ".". This approach -// is safe because dots are not allowed in the database names themselves. -func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, err error) { +// string, and tries to extract the cluster identifier if present. The cluster +// identifier should be embedded in the database parameter using the dot (".") +// delimiter in the form of ".". This +// approach is safe because dots are not allowed in the database names +// themselves. +func parseDatabaseParam(databaseParam string) (clusterIdentifier, databaseName string, err error) { // Database param is not provided. if databaseParam == "" { return "", "", nil @@ -664,21 +693,21 @@ func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, return "", databaseParam, nil } - clusterName, databaseName = parts[0], parts[1] + clusterIdentifier, databaseName = parts[0], parts[1] // Ensure that the param is in the right format if the delimiter is provided. - if len(parts) > 2 || clusterName == "" || databaseName == "" { + if len(parts) > 2 || clusterIdentifier == "" || databaseName == "" { return "", "", errors.New("invalid database param") } - return clusterName, databaseName, nil + return clusterIdentifier, databaseName, nil } // parseOptionsParam parses the options parameter from the PG connection string, -// and tries to return the cluster name if present. It also returns the options -// parameter with the cluster name stripped out. Just like PostgreSQL, the -// sqlproxy supports three different ways to set a run-time parameter through -// its command-line options: +// and tries to return the cluster identifier if present. It also returns the +// options parameter with the cluster key stripped out. Just like PostgreSQL, +// the sqlproxy supports three different ways to set a run-time parameter +// through its command-line options: // -c NAME=VALUE (commonly used throughout documentation around PGOPTIONS) // -cNAME=VALUE // --NAME=VALUE @@ -688,22 +717,23 @@ func parseDatabaseParam(databaseParam string) (clusterName, databaseName string, // parse this, we need to traverse the string from left to right, and look at // every single argument, but that involves quite a bit of work, so we'll punt // for now. -func parseOptionsParam(optionsParam string) (clusterName, newOptionsParam string, err error) { +func parseOptionsParam(optionsParam string) (clusterIdentifier, newOptionsParam string, err error) { // Only search up to 2 in case of large inputs. - matches := clusterNameLongOptionRE.FindAllStringSubmatch(optionsParam, 2 /* n */) + matches := clusterIdentifierLongOptionRE.FindAllStringSubmatch(optionsParam, 2 /* n */) if len(matches) == 0 { return "", optionsParam, nil } if len(matches) > 1 { // Technically we could still allow requests to go through if all - // cluster names match, but we don't want to parse the entire string, so - // we will just error out if at least two cluster flags are provided. + // cluster identifiers match, but we don't want to parse the entire + // string, so we will just error out if at least two cluster flags are + // provided. return "", "", errors.New("multiple cluster flags provided") } // Length of each match should always be 2 with the given regex, one for - // the full string, and the other for the cluster name. + // the full string, and the other for the cluster identifier. if len(matches[0]) != 2 { // We don't want to panic here. return "", "", errors.New("internal server error") @@ -713,7 +743,27 @@ func parseOptionsParam(optionsParam string) (clusterName, newOptionsParam string if len(matches[0][1]) == 0 { return "", "", errors.New("invalid cluster flag") } + newOptionsParam = strings.ReplaceAll(optionsParam, matches[0][0], "") newOptionsParam = strings.TrimSpace(newOptionsParam) return matches[0][1], newOptionsParam, nil } + +const clusterIdentifierHint = `Ensure that your cluster identifier is uniquely specified using any of the +following methods: + +1) Database parameter: + Use "." as the database parameter. + (e.g. database="active-roach-42.defaultdb") + +2) Options parameter: + Use "--cluster=" as the options parameter. + (e.g. options="--cluster=active-roach-42") + +For more details, please visit our docs site at: + https://www.cockroachlabs.com/docs/cockroachcloud/connect-to-a-serverless-cluster +` + +const clusterNameFormHint = "Cluster identifiers come in the form of - (e.g. lazy-roach-3)." + +const missingTenantIDHint = "Did you forget to include your tenant ID in the cluster identifier?" diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 00c9ab787d63..b48c0d53d6bf 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -142,21 +143,21 @@ func TestFailedConnection(t *testing.T) { // TenantID rejected as malformed. te.TestConnectErr( ctx, t, u+"?options=--cluster=dimdog&sslmode="+sslmode, - codeParamsRoutingFailed, "invalid cluster name 'dimdog'", + codeParamsRoutingFailed, "invalid cluster identifier 'dimdog'", ) require.Equal(t, int64(1+(i*3)), s.metrics.RoutingErrCount.Count()) // No cluster name and TenantID. te.TestConnectErr( ctx, t, u+"?sslmode="+sslmode, - codeParamsRoutingFailed, "missing cluster name in connection string", + codeParamsRoutingFailed, "missing cluster identifier", ) require.Equal(t, int64(2+(i*3)), s.metrics.RoutingErrCount.Count()) // Bad TenantID. Ensure that we don't leak any parsing errors. te.TestConnectErr( ctx, t, u+"?options=--cluster=dim-dog-foo3&sslmode="+sslmode, - codeParamsRoutingFailed, "invalid cluster name 'dim-dog-foo3'", + codeParamsRoutingFailed, "invalid cluster identifier 'dim-dog-foo3'", ) require.Equal(t, int64(3+(i*3)), s.metrics.RoutingErrCount.Count()) } @@ -742,51 +743,50 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { expectedTenantID uint64 expectedParams map[string]string expectedError string + expectedHint string }{ { name: "empty params", params: map[string]string{}, - expectedError: "missing cluster name in connection string", + expectedError: "missing cluster identifier", + expectedHint: clusterIdentifierHint, }, { - name: "cluster name is not provided", + name: "cluster identifier is not provided", params: map[string]string{ "database": "defaultdb", "options": "--foo=bar", }, - expectedError: "missing cluster name in connection string", + expectedError: "missing cluster identifier", + expectedHint: clusterIdentifierHint, }, { - name: "multiple similar cluster names", - params: map[string]string{ - "database": "happy-koala-7.defaultdb", - "options": "--cluster=happy-koala", - }, - expectedError: "multiple cluster names provided", - }, - { - name: "multiple different cluster names", + name: "multiple different cluster identifiers", params: map[string]string{ "database": "happy-koala-7.defaultdb", "options": "--cluster=happy-tiger", }, - expectedError: "multiple cluster names provided", + expectedError: "multiple different cluster identifiers provided", + expectedHint: "Is 'happy-koala-7' or 'happy-tiger' the identifier for the cluster that you're connecting to?\n--\n" + + clusterIdentifierHint, }, { - name: "invalid cluster name in database param", + name: "invalid cluster identifier in database param", params: map[string]string{ // Cluster names need to be between 6 to 20 alphanumeric characters. "database": "short-0.defaultdb", }, - expectedError: "invalid cluster name 'short-0'", + expectedError: "invalid cluster identifier 'short-0'", + expectedHint: "Is 'short' a valid cluster name?\n--\n" + clusterNameFormHint, }, { - name: "invalid cluster name in options param", + name: "invalid cluster identifier in options param", params: map[string]string{ // Cluster names need to be between 6 to 20 alphanumeric characters. "options": "--cluster=cockroachlabsdotcomfoobarbaz-0", }, - expectedError: "invalid cluster name 'cockroachlabsdotcomfoobarbaz-0'", + expectedError: "invalid cluster identifier 'cockroachlabsdotcomfoobarbaz-0'", + expectedHint: "Is 'cockroachlabsdotcomfoobarbaz' a valid cluster name?\n--\n" + clusterNameFormHint, }, { name: "invalid database param (1)", @@ -827,30 +827,45 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { { name: "no tenant id", params: map[string]string{"database": "happy2koala.defaultdb"}, - expectedError: "invalid cluster name 'happy2koala'", + expectedError: "invalid cluster identifier 'happy2koala'", + expectedHint: missingTenantIDHint + "\n--\n" + clusterNameFormHint, }, { name: "missing tenant id", params: map[string]string{"database": "happy2koala-.defaultdb"}, - expectedError: "invalid cluster name 'happy2koala-'", + expectedError: "invalid cluster identifier 'happy2koala-'", + expectedHint: missingTenantIDHint + "\n--\n" + clusterNameFormHint, }, { name: "missing cluster name", params: map[string]string{"database": "-7.defaultdb"}, - expectedError: "invalid cluster name '-7'", + expectedError: "invalid cluster identifier '-7'", + expectedHint: "Is '' a valid cluster name?\n--\n" + clusterNameFormHint, }, { name: "bad tenant id", params: map[string]string{"database": "happy-koala-0-7a.defaultdb"}, - expectedError: "invalid cluster name 'happy-koala-0-7a'", + expectedError: "invalid cluster identifier 'happy-koala-0-7a'", + expectedHint: "Is '7a' a valid tenant ID?\n--\n" + clusterNameFormHint, }, { name: "zero tenant id", params: map[string]string{"database": "happy-koala-0.defaultdb"}, - expectedError: "invalid cluster name 'happy-koala-0'", + expectedError: "invalid cluster identifier 'happy-koala-0'", + expectedHint: "Tenant ID 0 is invalid.", + }, + { + name: "multiple similar cluster identifiers", + params: map[string]string{ + "database": "happy-koala-7.defaultdb", + "options": "--cluster=happy-koala-7", + }, + expectedClusterName: "happy-koala", + expectedTenantID: 7, + expectedParams: map[string]string{"database": "defaultdb"}, }, { - name: "cluster name in database param", + name: "cluster identifier in database param", params: map[string]string{ "database": "happy-koala-7.defaultdb", "foo": "bar", @@ -860,7 +875,7 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { expectedParams: map[string]string{"database": "defaultdb", "foo": "bar"}, }, { - name: "valid cluster name with invalid arrangements", + name: "valid cluster identifier with invalid arrangements", params: map[string]string{ "database": "defaultdb", "options": "-c --cluster=happy-koala-7 -c -c -c", @@ -873,7 +888,7 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { }, }, { - name: "short option: cluster name in options param", + name: "short option: cluster identifier in options param", params: map[string]string{ "database": "defaultdb", "options": "-ccluster=happy-koala-7", @@ -883,7 +898,7 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { expectedParams: map[string]string{"database": "defaultdb"}, }, { - name: "short option with spaces: cluster name in options param", + name: "short option with spaces: cluster identifier in options param", params: map[string]string{ "database": "defaultdb", "options": "-c cluster=happy-koala-7", @@ -893,7 +908,7 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { expectedParams: map[string]string{"database": "defaultdb"}, }, { - name: "long option: cluster name in options param", + name: "long option: cluster identifier in options param", params: map[string]string{ "database": "defaultdb", "options": "--cluster=happy-koala-7\t--foo=test", @@ -906,7 +921,7 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { }, }, { - name: "long option: cluster name in options param with other options", + name: "long option: cluster identifier in options param with other options", params: map[string]string{ "database": "defaultdb", "options": "-csearch_path=public --cluster=happy-koala-7\t--foo=test", @@ -946,6 +961,9 @@ func TestClusterNameAndTenantFromParams(t *testing.T) { require.Equal(t, tc.expectedParams, outMsg.Parameters) } else { require.EqualErrorf(t, err, tc.expectedError, "failed test case\n%+v", tc) + + pgerr := pgerror.Flatten(err) + require.Equal(t, tc.expectedHint, pgerr.Hint) } // Check that the original parameters were not modified. diff --git a/pkg/cmd/dev/io/exec/exec.go b/pkg/cmd/dev/io/exec/exec.go index 4667119e3320..b2d10576e173 100644 --- a/pkg/cmd/dev/io/exec/exec.go +++ b/pkg/cmd/dev/io/exec/exec.go @@ -97,10 +97,24 @@ func (e *Exec) CommandContextWithInput( return e.commandContextImpl(ctx, r, false, name, args...) } +// CommandContextWithEnv is like CommandContextInheritingStdStreams, but +// accepting an additional argument for environment variables. +func (e *Exec) CommandContextWithEnv( + ctx context.Context, env []string, name string, args ...string, +) error { + return e.commandContextInheritingStdStreamsImpl(ctx, env, name, args...) +} + // CommandContextInheritingStdStreams is like CommandContext, but stdin, // stdout, and stderr are passed directly to the terminal. func (e *Exec) CommandContextInheritingStdStreams( ctx context.Context, name string, args ...string, +) error { + return e.commandContextInheritingStdStreamsImpl(ctx, nil, name, args...) +} + +func (e *Exec) commandContextInheritingStdStreamsImpl( + ctx context.Context, env []string, name string, args ...string, ) error { var command string if len(args) > 0 { @@ -117,6 +131,7 @@ func (e *Exec) CommandContextInheritingStdStreams( cmd.Stdout = e.stdout cmd.Stderr = e.stderr cmd.Dir = e.dir + cmd.Env = env if err := cmd.Start(); err != nil { return err diff --git a/pkg/cmd/dev/lint.go b/pkg/cmd/dev/lint.go index 477ac6a9533f..457dd52b9b29 100644 --- a/pkg/cmd/dev/lint.go +++ b/pkg/cmd/dev/lint.go @@ -12,6 +12,8 @@ package main import ( "fmt" + "os" + "strings" "github.com/spf13/cobra" ) @@ -23,8 +25,10 @@ func makeLintCmd(runE func(cmd *cobra.Command, args []string) error) *cobra.Comm Short: `Run the specified linters`, Long: `Run the specified linters.`, Example: ` - dev lint --filter=TestLowercaseFunctionNames --short --timeout=1m`, - Args: cobra.NoArgs, + dev lint --filter=TestLowercaseFunctionNames --short --timeout=1m + dev lint pkg/cmd/dev +`, + Args: cobra.MaximumNArgs(1), RunE: runE, } addCommonBuildFlags(lintCmd) @@ -32,7 +36,7 @@ func makeLintCmd(runE func(cmd *cobra.Command, args []string) error) *cobra.Comm return lintCmd } -func (d *dev) lint(cmd *cobra.Command, _ []string) error { +func (d *dev) lint(cmd *cobra.Command, pkgs []string) error { ctx := cmd.Context() filter := mustGetFlagString(cmd, filterFlag) timeout := mustGetFlagDuration(cmd, timeoutFlag) @@ -56,7 +60,15 @@ func (d *dev) lint(cmd *cobra.Command, _ []string) error { if filter != "" { args = append(args, "-test.run", fmt.Sprintf("Lint/%s", filter)) } - logCommand("bazel", args...) + if len(pkgs) > 0 { + pkg := strings.TrimRight(pkgs[0], "/") + if !strings.HasPrefix(pkg, "./") { + pkg = "./" + pkg + } + env := os.Environ() + env = append(env, fmt.Sprintf("PKG=%s", pkg)) + return d.exec.CommandContextWithEnv(ctx, env, "bazel", args...) + } return d.exec.CommandContextInheritingStdStreams(ctx, "bazel", args...) }