Skip to content

Commit

Permalink
utilccl,kvccl: add IsEnterpriseEnabled for faster license checks
Browse files Browse the repository at this point in the history
`utilccl.CheckEnterpriseEnabled()` is used to check whether a valid
enterprise license exists for a given feature. If no valid license is
found, it returns an error with specific details.

However, `kvccl` used this function in follower read hot paths, and
instantiating an error when follower reads are unavailable could have
significant overhead -- see e.g. #62447.

This patch adds `IsEnterpriseEnabled()`, which has the same behavior as
`CheckEnterpriseEnabled()` but returns a boolean instead. This is
significantly faster since we can avoid instantiating a custom error
each time. `kvccl` is also updated to use this in hot paths.

Release note: None
  • Loading branch information
erikgrinaker committed Mar 31, 2021
1 parent 99c29ef commit 0f424ea
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
12 changes: 11 additions & 1 deletion pkg/ccl/kvccl/kvfollowerreadsccl/followerreads.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,26 @@ func getGlobalReadsLead(clock *hlc.Clock) time.Duration {
return clock.MaxOffset()
}

// checkEnterpriseEnabled checks whether the enterprise feature for follower
// reads is enabled, returning a detailed error if not. It is not suitable for
// use in hot paths since a new error may be instantiated on each call.
func checkEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) error {
org := sql.ClusterOrganization.Get(&st.SV)
return utilccl.CheckEnterpriseEnabled(st, clusterID, org, "follower reads")
}

// isEnterpriseEnabled is faster than checkEnterpriseEnabled, and suitable
// for hot paths.
func isEnterpriseEnabled(clusterID uuid.UUID, st *cluster.Settings) bool {
org := sql.ClusterOrganization.Get(&st.SV)
return utilccl.IsEnterpriseEnabled(st, clusterID, org, "follower reads")
}

func checkFollowerReadsEnabled(clusterID uuid.UUID, st *cluster.Settings) bool {
if !kvserver.FollowerReadsEnabled.Get(&st.SV) {
return false
}
return checkEnterpriseEnabled(clusterID, st) == nil
return isEnterpriseEnabled(clusterID, st)
}

func evalFollowerReadOffset(clusterID uuid.UUID, st *cluster.Settings) (time.Duration, error) {
Expand Down
49 changes: 41 additions & 8 deletions pkg/ccl/utilccl/license_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ const (
testingEnterpriseEnabled = 1
)

// errEnterpriseRequired is returned by check() when the caller does
// not request detailed errors.
var errEnterpriseRequired = pgerror.New(pgcode.CCLValidLicenseRequired,
"a valid enterprise license is required")

// TestingEnableEnterprise allows overriding the license check in tests.
func TestingEnableEnterprise() func() {
before := atomic.LoadInt32(&testingEnterprise)
Expand All @@ -78,11 +83,21 @@ func TestingDisableEnterprise() func() {
// CheckEnterpriseEnabled returns a non-nil error if the requested enterprise
// feature is not enabled, including information or a link explaining how to
// enable it.
//
// This should not be used in hot paths, since an unavailable feature will
// result in a new error being instantiated for every call -- use
// IsEnterpriseEnabled() instead.
func CheckEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) error {
if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled {
return nil
}
return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature)
return checkEnterpriseEnabledAt(st, timeutil.Now(), cluster, org, feature, true /* withDetails */)
}

// IsEnterpriseEnabled returns whether the requested enterprise feature is
// enabled. It is faster than CheckEnterpriseEnabled, since it does not return
// details about why the feature is unavailable, and can therefore be used in
// hot paths.
func IsEnterpriseEnabled(st *cluster.Settings, cluster uuid.UUID, org, feature string) bool {
return checkEnterpriseEnabledAt(
st, timeutil.Now(), cluster, org, feature, false /* withDetails */) == nil
}

func init() {
Expand Down Expand Up @@ -114,8 +129,11 @@ func TimeToEnterpriseLicenseExpiry(
}

func checkEnterpriseEnabledAt(
st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string,
st *cluster.Settings, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool,
) error {
if atomic.LoadInt32(&testingEnterprise) == testingEnterpriseEnabled {
return nil
}
var lic *licenseccl.License
// FIXME(tschottdorf): see whether it makes sense to cache the decoded
// license.
Expand All @@ -125,7 +143,7 @@ func checkEnterpriseEnabledAt(
return err
}
}
return check(lic, at, cluster, org, feature)
return check(lic, at, cluster, org, feature, withDetails)
}

func getLicenseType(st *cluster.Settings) (string, error) {
Expand All @@ -149,9 +167,15 @@ func decode(s string) (*licenseccl.License, error) {
return lic, err
}

// check returns an error if the license is empty or not currently valid.
func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string) error {
// check returns an error if the license is empty or not currently valid. If
// withDetails is false, a generic error message is returned for performance.
func check(
l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature string, withDetails bool,
) error {
if l == nil {
if !withDetails {
return errEnterpriseRequired
}
// TODO(dt): link to some stable URL that then redirects to a helpful page
// that explains what to do here.
link := "https://cockroachlabs.com/pricing?cluster="
Expand All @@ -168,6 +192,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature
// suddenly throwing errors at them.
if l.ValidUntilUnixSec > 0 && l.Type != licenseccl.License_Enterprise {
if expiration := timeutil.Unix(l.ValidUntilUnixSec, 0); at.After(expiration) {
if !withDetails {
return errEnterpriseRequired
}
licensePrefix := redact.SafeString("")
switch l.Type {
case licenseccl.License_NonCommercial:
Expand All @@ -190,6 +217,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature
if strings.EqualFold(l.OrganizationName, org) {
return nil
}
if !withDetails {
return errEnterpriseRequired
}
return pgerror.Newf(pgcode.CCLValidLicenseRequired,
"license valid only for %q", l.OrganizationName)
}
Expand All @@ -201,6 +231,9 @@ func check(l *licenseccl.License, at time.Time, cluster uuid.UUID, org, feature
}

// no match, so compose an error message.
if !withDetails {
return errEnterpriseRequired
}
var matches bytes.Buffer
for i, c := range l.ClusterID {
if i > 0 {
Expand Down

0 comments on commit 0f424ea

Please sign in to comment.