diff --git a/pkg/ccl/utilccl/BUILD.bazel b/pkg/ccl/utilccl/BUILD.bazel index 5947e49a4f4c..ca7cc8ffeabf 100644 --- a/pkg/ccl/utilccl/BUILD.bazel +++ b/pkg/ccl/utilccl/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/types", + "//pkg/util/syncutil", "//pkg/util/timeutil", "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/ccl/utilccl/license_check.go b/pkg/ccl/utilccl/license_check.go index d2c494f7ac34..184ec5423d48 100644 --- a/pkg/ccl/utilccl/license_check.go +++ b/pkg/ccl/utilccl/license_check.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" @@ -62,6 +63,14 @@ const ( var errEnterpriseRequired = pgerror.New(pgcode.CCLValidLicenseRequired, "a valid enterprise license is required") +// decodeCache is used to cache licenses for decodeCached(). +var decodeCache = struct { + syncutil.RWMutex + licenses map[string]*licenseccl.License +}{ + licenses: map[string]*licenseccl.License{}, +} + // TestingEnableEnterprise allows overriding the license check in tests. func TestingEnableEnterprise() func() { before := atomic.LoadInt32(&testingEnterprise) @@ -113,11 +122,9 @@ func TimeToEnterpriseLicenseExpiry( ctx context.Context, st *cluster.Settings, asOf time.Time, ) (time.Duration, error) { var lic *licenseccl.License - // FIXME(tschottdorf): see whether it makes sense to cache the decoded - // license. if str := enterpriseLicense.Get(&st.SV); str != "" { var err error - if lic, err = decode(str); err != nil { + if lic, err = decodeCached(str); err != nil { return 0, err } } else { @@ -135,11 +142,9 @@ func checkEnterpriseEnabledAt( return nil } var lic *licenseccl.License - // FIXME(tschottdorf): see whether it makes sense to cache the decoded - // license. if str := enterpriseLicense.Get(&st.SV); str != "" { var err error - if lic, err = decode(str); err != nil { + if lic, err = decodeCached(str); err != nil { return err } } @@ -151,7 +156,7 @@ func getLicenseType(st *cluster.Settings) (string, error) { if str == "" { return "None", nil } - lic, err := decode(str) + lic, err := decodeCached(str) if err != nil { return "", err } @@ -164,7 +169,27 @@ func decode(s string) (*licenseccl.License, error) { if err != nil { return nil, pgerror.WithCandidateCode(err, pgcode.Syntax) } - return lic, err + return lic, nil +} + +// decodeCache decodes a base64-encoded License and caches the result. +func decodeCached(s string) (*licenseccl.License, error) { + decodeCache.RLock() + lic, ok := decodeCache.licenses[s] + decodeCache.RUnlock() + if ok { + return lic, nil + } + + lic, err := decode(s) + if err != nil { + return nil, err + } + + decodeCache.Lock() + decodeCache.licenses[s] = lic + decodeCache.Unlock() + return lic, nil } // check returns an error if the license is empty or not currently valid. If diff --git a/pkg/ccl/utilccl/license_check_test.go b/pkg/ccl/utilccl/license_check_test.go index 206b6e6f250b..055bfe8f8227 100644 --- a/pkg/ccl/utilccl/license_check_test.go +++ b/pkg/ccl/utilccl/license_check_test.go @@ -63,7 +63,7 @@ func TestSettingAndCheckingLicense(t *testing.T) { if err := updater.Set("enterprise.license", tc.lic, "s"); err != nil { t.Fatal(err) } - err := checkEnterpriseEnabledAt(st, tc.checkTime, tc.checkCluster, "", "") + err := checkEnterpriseEnabledAt(st, tc.checkTime, tc.checkCluster, "", "", true) if !testutils.IsError(err, tc.err) { l, _ := decode(tc.lic) t.Fatalf("%d: lic %v, update by %T, checked by %s at %s, got %q", i, l, updater, tc.checkCluster, tc.checkTime, err) diff --git a/pkg/ccl/utilccl/license_test.go b/pkg/ccl/utilccl/license_test.go index 23f9441dd1d6..adc064a55d2b 100644 --- a/pkg/ccl/utilccl/license_test.go +++ b/pkg/ccl/utilccl/license_test.go @@ -83,7 +83,7 @@ func TestLicense(t *testing.T) { } } if err := check( - lic, tc.checkTime, tc.checkCluster, tc.checkOrg, "", + lic, tc.checkTime, tc.checkCluster, tc.checkOrg, "", true, ); !testutils.IsError(err, tc.err) { t.Fatalf("%d: lic for %s to %s, checked by %s at %s.\n got %q", i, tc.grantedTo, tc.expiration, tc.checkCluster, tc.checkTime, err) @@ -108,7 +108,7 @@ func TestExpiredLicenseLanguage(t *testing.T) { Type: licenseccl.License_Evaluation, ValidUntilUnixSec: 1, } - err := check(lic, timeutil.Now(), uuid.MakeV4(), "", "RESTORE") + err := check(lic, timeutil.Now(), uuid.MakeV4(), "", "RESTORE", true) expected := "Use of RESTORE requires an enterprise license. Your evaluation license expired on " + "January 1, 1970. If you're interested in getting a new license, please contact " + "subscriptions@cockroachlabs.com and we can help you out."