diff --git a/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go b/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go index 1bb28b69a566..11dc1a6c3578 100644 --- a/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go +++ b/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go @@ -77,16 +77,26 @@ func getGlobalReadsLead(clock *hlc.Clock) time.Duration { return clock.MaxOffset() } +// checkEnterpriseEnabled checks whether the enterprise feature for follower +// reads is enabled, returning a detailed error if not. It is not suitable for +// use in hot paths since a new error may be instantiated on each call. func checkEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) error { org := sql.ClusterOrganization.Get(&st.SV) return utilccl.CheckEnterpriseEnabled(st, clusterID, org, "follower reads") } +// isEnterpriseEnabled is faster than checkEnterpriseEnabled, and suitable +// for hot paths. +func isEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) bool { + org := sql.ClusterOrganization.Get(&st.SV) + return utilccl.IsEnterpriseEnabled(st, clusterID, org, "follower reads") +} + func checkFollowerReadsEnabled(clusterID uuid.UUID, st *cluster.Settings) bool { if !kvserver.FollowerReadsEnabled.Get(&st.SV) { return false } - return checkEnterpriseEnabled(clusterID, st) == nil + return isEnterpriseEnabled(clusterID, st) } func evalFollowerReadOffset(clusterID uuid.UUID, st *cluster.Settings) (time.Duration, error) { diff --git a/pkg/ccl/utilccl/license_check.go b/pkg/ccl/utilccl/license_check.go index b3e27afd149c..d2c494f7ac34 100644 --- a/pkg/ccl/utilccl/license_check.go +++ b/pkg/ccl/utilccl/license_check.go @@ -57,6 +57,11 @@ const ( testingEnterpriseEnabled = 1 ) +// errEnterpriseRequired is returned by check() when the caller does +// not request detailed errors. +var errEnterpriseRequired = pgerror.New(pgcode.CCLValidLicenseRequired, + "a valid enterprise license is required") + // TestingEnableEnterprise allows overriding the license check in tests. func TestingEnableEnterprise() func() { before := atomic.LoadInt32(&testingEnterprise) @@ -78,11 +83,21 @@ func TestingDisableEnterprise() func() { // CheckEnterpriseEnabled returns a non-nil error if the requested enterprise // feature is not enabled, including information or a link explaining how to // enable it. +// +// This should not be used in hot paths, since an unavailable feature will +// result in a new error being instantiated for every call -- use +// IsEnterpriseEnabled() instead. func CheckEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) error { - if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled { - return nil - } - return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature) + return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature, true /* withDetails */) +} + +// IsEnterpriseEnabled returns whether the requested enterprise feature is +// enabled. It is faster than CheckEnterpriseEnabled, since it does not return +// details about why the feature is unavailable, and can therefore be used in +// hot paths. +func IsEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) bool { + return checkEnterpriseEnabledAt( + st, timeutil.Now(), cluster, org, feature, false /* withDetails */) == nil } func init() { @@ -114,8 +129,11 @@ func TimeToEnterpriseLicenseExpiry( } func checkEnterpriseEnabledAt( - st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string, + st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool, ) error { + if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled { + return nil + } var lic *licenseccl.License // FIXME(tschottdorf): see whether it makes sense to cache the decoded // license. @@ -125,7 +143,7 @@ func checkEnterpriseEnabledAt( return err } } - return check(lic, at, cluster, org, feature) + return check(lic, at, cluster, org, feature, withDetails) } func getLicenseType(st *cluster.Settings) (string, error) { @@ -149,9 +167,15 @@ func decode(s string) (*licenseccl.License, error) { return lic, err } -// check returns an error if the license is empty or not currently valid. -func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string) error { +// check returns an error if the license is empty or not currently valid. If +// withDetails is false, a generic error message is returned for performance. +func check( + l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool, +) error { if l == nil { + if !withDetails { + return errEnterpriseRequired + } // TODO(dt): link to some stable URL that then redirects to a helpful page // that explains what to do here. link := "https://cockroachlabs.com/pricing?cluster=" @@ -168,6 +192,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature // suddenly throwing errors at them. if l.ValidUntilUnixSec > 0 && l.Type != licenseccl.License_Enterprise { if expiration := timeutil.Unix(l.ValidUntilUnixSec, 0); at.After(expiration) { + if !withDetails { + return errEnterpriseRequired + } licensePrefix := redact.SafeString("") switch l.Type { case licenseccl.License_NonCommercial: @@ -190,6 +217,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature if strings.EqualFold(l.OrganizationName, org) { return nil } + if !withDetails { + return errEnterpriseRequired + } return pgerror.Newf(pgcode.CCLValidLicenseRequired, "license valid only for %q", l.OrganizationName) } @@ -201,6 +231,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature } // no match, so compose an error message. + if !withDetails { + return errEnterpriseRequired + } var matches bytes.Buffer for i, c := range l.ClusterID { if i > 0 {