diff --git a/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go b/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go index 636abd8ae5d2..899679349e22 100644 --- a/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go +++ b/pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go @@ -77,16 +77,26 @@ func getGlobalReadsLead(clock *hlc.Clock) time.Duration { return clock.MaxOffset() } +// checkEnterpriseEnabled checks whether the enterprise feature for follower +// reads is enabled, returning a detailed error if not. It is not suitable for +// use in hot paths since a new error may be instantiated on each call. func checkEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) error { org := sql.ClusterOrganization.Get(&st.SV) return utilccl.CheckEnterpriseEnabled(st, clusterID, org, "follower reads") } +// isEnterpriseEnabled is faster than checkEnterpriseEnabled, and suitable +// for hot paths. +func isEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) bool { + org := sql.ClusterOrganization.Get(&st.SV) + return utilccl.IsEnterpriseEnabled(st, clusterID, org, "follower reads") +} + func checkFollowerReadsEnabled(clusterID uuid.UUID, st *cluster.Settings) bool { if !kvserver.FollowerReadsEnabled.Get(&st.SV) { return false } - return checkEnterpriseEnabled(clusterID, st) == nil + return isEnterpriseEnabled(clusterID, st) } func evalFollowerReadOffset(clusterID uuid.UUID, st *cluster.Settings) (time.Duration, error) { diff --git a/pkg/ccl/utilccl/license_check.go b/pkg/ccl/utilccl/license_check.go index b3e27afd149c..8d4e101367f8 100644 --- a/pkg/ccl/utilccl/license_check.go +++ b/pkg/ccl/utilccl/license_check.go @@ -57,6 +57,15 @@ const ( testingEnterpriseEnabled = 1 ) +// errEnterpriseRequired is returned by check() when the caller does +// not request detailed errors. +var errEnterpriseRequired = pgerror.New(pgcode.CCLValidLicenseRequired, + "a valid enterprise license is required") + +// licenseCacheKey is used to cache licenses in cluster.Settings.Cache, +// keeping the entries private. +type licenseCacheKey string + // TestingEnableEnterprise allows overriding the license check in tests. func TestingEnableEnterprise() func() { before := atomic.LoadInt32(&testingEnterprise) @@ -78,11 +87,21 @@ func TestingDisableEnterprise() func() { // CheckEnterpriseEnabled returns a non-nil error if the requested enterprise // feature is not enabled, including information or a link explaining how to // enable it. +// +// This should not be used in hot paths, since an unavailable feature will +// result in a new error being instantiated for every call -- use +// IsEnterpriseEnabled() instead. func CheckEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) error { - if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled { - return nil - } - return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature) + return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature, true /* withDetails */) +} + +// IsEnterpriseEnabled returns whether the requested enterprise feature is +// enabled. It is faster than CheckEnterpriseEnabled, since it does not return +// details about why the feature is unavailable, and can therefore be used in +// hot paths. +func IsEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) bool { + return checkEnterpriseEnabledAt( + st, timeutil.Now(), cluster, org, feature, false /* withDetails */) == nil } func init() { @@ -97,47 +116,56 @@ func init() { func TimeToEnterpriseLicenseExpiry( ctx context.Context, st *cluster.Settings, asOf time.Time, ) (time.Duration, error) { - var lic *licenseccl.License - // FIXME(tschottdorf): see whether it makes sense to cache the decoded - // license. - if str := enterpriseLicense.Get(&st.SV); str != "" { - var err error - if lic, err = decode(str); err != nil { - return 0, err - } - } else { - return 0, nil + license, err := getLicense(st) + if err != nil || license == nil { + return 0, err } - expiration := timeutil.Unix(lic.ValidUntilUnixSec, 0) + expiration := timeutil.Unix(license.ValidUntilUnixSec, 0) return expiration.Sub(asOf), nil } func checkEnterpriseEnabledAt( - st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string, + st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool, ) error { - var lic *licenseccl.License - // FIXME(tschottdorf): see whether it makes sense to cache the decoded - // license. - if str := enterpriseLicense.Get(&st.SV); str != "" { - var err error - if lic, err = decode(str); err != nil { - return err - } + if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled { + return nil + } + license, err := getLicense(st) + if err != nil { + return err } - return check(lic, at, cluster, org, feature) + return check(license, at, cluster, org, feature, withDetails) } -func getLicenseType(st *cluster.Settings) (string, error) { +// getLicense fetches the license from the given settings, using Settings.Cache +// to cache the decoded license (if any). The returned license must not be +// modified by the caller. +func getLicense(st *cluster.Settings) (*licenseccl.License, error) { str := enterpriseLicense.Get(&st.SV) if str == "" { - return "None", nil + return nil, nil } - lic, err := decode(str) + cacheKey := licenseCacheKey(str) + if cachedLicense, ok := st.Cache.Load(cacheKey); ok { + return cachedLicense.(*licenseccl.License), nil + } + license, err := decode(str) + if err != nil { + return nil, err + } + st.Cache.Store(cacheKey, license) + return license, nil +} + +func getLicenseType(st *cluster.Settings) (string, error) { + license, err := getLicense(st) if err != nil { return "", err + } else if license == nil { + return "None", nil } - return lic.Type.String(), nil + return license.Type.String(), nil } // decode attempts to read a base64 encoded License. @@ -146,12 +174,18 @@ func decode(s string) (*licenseccl.License, error) { if err != nil { return nil, pgerror.WithCandidateCode(err, pgcode.Syntax) } - return lic, err + return lic, nil } -// check returns an error if the license is empty or not currently valid. -func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string) error { +// check returns an error if the license is empty or not currently valid. If +// withDetails is false, a generic error message is returned for performance. +func check( + l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool, +) error { if l == nil { + if !withDetails { + return errEnterpriseRequired + } // TODO(dt): link to some stable URL that then redirects to a helpful page // that explains what to do here. link := "https://cockroachlabs.com/pricing?cluster=" @@ -168,6 +202,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature // suddenly throwing errors at them. if l.ValidUntilUnixSec > 0 && l.Type != licenseccl.License_Enterprise { if expiration := timeutil.Unix(l.ValidUntilUnixSec, 0); at.After(expiration) { + if !withDetails { + return errEnterpriseRequired + } licensePrefix := redact.SafeString("") switch l.Type { case licenseccl.License_NonCommercial: @@ -190,6 +227,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature if strings.EqualFold(l.OrganizationName, org) { return nil } + if !withDetails { + return errEnterpriseRequired + } return pgerror.Newf(pgcode.CCLValidLicenseRequired, "license valid only for %q", l.OrganizationName) } @@ -201,6 +241,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature } // no match, so compose an error message. + if !withDetails { + return errEnterpriseRequired + } var matches bytes.Buffer for i, c := range l.ClusterID { if i > 0 { diff --git a/pkg/ccl/utilccl/license_check_test.go b/pkg/ccl/utilccl/license_check_test.go index 206b6e6f250b..055bfe8f8227 100644 --- a/pkg/ccl/utilccl/license_check_test.go +++ b/pkg/ccl/utilccl/license_check_test.go @@ -63,7 +63,7 @@ func TestSettingAndCheckingLicense(t *testing.T) { if err := updater.Set("enterprise.license", tc.lic, "s"); err != nil { t.Fatal(err) } - err := checkEnterpriseEnabledAt(st, tc.checkTime, tc.checkCluster, "", "") + err := checkEnterpriseEnabledAt(st, tc.checkTime, tc.checkCluster, "", "", true) if !testutils.IsError(err, tc.err) { l, _ := decode(tc.lic) t.Fatalf("%d: lic %v, update by %T, checked by %s at %s, got %q", i, l, updater, tc.checkCluster, tc.checkTime, err) diff --git a/pkg/ccl/utilccl/license_test.go b/pkg/ccl/utilccl/license_test.go index 23f9441dd1d6..adc064a55d2b 100644 --- a/pkg/ccl/utilccl/license_test.go +++ b/pkg/ccl/utilccl/license_test.go @@ -83,7 +83,7 @@ func TestLicense(t *testing.T) { } } if err := check( - lic, tc.checkTime, tc.checkCluster, tc.checkOrg, "", + lic, tc.checkTime, tc.checkCluster, tc.checkOrg, "", true, ); !testutils.IsError(err, tc.err) { t.Fatalf("%d: lic for %s to %s, checked by %s at %s.\n got %q", i, tc.grantedTo, tc.expiration, tc.checkCluster, tc.checkTime, err) @@ -108,7 +108,7 @@ func TestExpiredLicenseLanguage(t *testing.T) { Type: licenseccl.License_Evaluation, ValidUntilUnixSec: 1, } - err := check(lic, timeutil.Now(), uuid.MakeV4(), "", "RESTORE") + err := check(lic, timeutil.Now(), uuid.MakeV4(), "", "RESTORE", true) expected := "Use of RESTORE requires an enterprise license. Your evaluation license expired on " + "January 1, 1970. If you're interested in getting a new license, please contact " + "subscriptions@cockroachlabs.com and we can help you out." diff --git a/pkg/settings/cluster/cluster_settings.go b/pkg/settings/cluster/cluster_settings.go index aa0cd7e7a238..13291e15d4d6 100644 --- a/pkg/settings/cluster/cluster_settings.go +++ b/pkg/settings/cluster/cluster_settings.go @@ -12,6 +12,7 @@ package cluster import ( "context" + "sync" "sync/atomic" "github.com/cockroachdb/cockroach/pkg/clusterversion" @@ -51,6 +52,10 @@ type Settings struct { // Setting the active cluster version has a very specific, intended usage // pattern. Look towards the interface itself for more commentary. Version clusterversion.Handle + + // Cache can be used for arbitrary caching, e.g. to cache decoded + // enterprises licenses for utilccl.CheckEnterpriseEnabled(). + Cache sync.Map } // TelemetryOptOut is a place for controlling whether to opt out of telemetry or not.