diff --git a/auth/auth.go b/auth/auth.go index ab7f587fbe..caad6294c5 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -99,12 +99,12 @@ func NewAuthenticator(datastore base.DataStore, channelComputer ChannelComputer, } } -func DefaultAuthenticatorOptions() AuthenticatorOptions { +func DefaultAuthenticatorOptions(ctx context.Context) AuthenticatorOptions { return AuthenticatorOptions{ ClientPartitionWindow: base.DefaultClientPartitionWindow, SessionCookieName: DefaultCookieName, BcryptCost: DefaultBcryptCost, - LogCtx: context.Background(), + LogCtx: ctx, } } @@ -722,7 +722,7 @@ func (auth *Authenticator) AuthenticateUntrustedJWT(rawToken string, oidcProvide } if authenticator == nil { for _, provider := range oidcProviders { - if provider.ValidFor(issuer, audiences) { + if provider.ValidFor(auth.LogCtx, issuer, audiences) { base.TracefCtx(auth.LogCtx, base.KeyAuth, "Using OIDC provider %v", base.UD(provider.Issuer)) authenticator = provider break @@ -731,7 +731,7 @@ func (auth *Authenticator) AuthenticateUntrustedJWT(rawToken string, oidcProvide } if authenticator == nil { for _, provider := range localJWT { - if provider.ValidFor(issuer, audiences) { + if provider.ValidFor(auth.LogCtx, issuer, audiences) { base.TracefCtx(auth.LogCtx, base.KeyAuth, "Using local JWT provider %v", base.UD(provider.Issuer)) authenticator = provider break @@ -744,7 +744,7 @@ func (auth *Authenticator) AuthenticateUntrustedJWT(rawToken string, oidcProvide } var identity *Identity - identity, err = authenticator.verifyToken(context.TODO(), rawToken, callbackURLFunc) + identity, err = authenticator.verifyToken(auth.LogCtx, rawToken, callbackURLFunc) if err != nil { base.DebugfCtx(auth.LogCtx, base.KeyAuth, "JWT invalid: %v", err) return nil, PrincipalConfig{}, base.HTTPErrorf(http.StatusUnauthorized, "Invalid JWT") diff --git a/auth/auth_test.go b/auth/auth_test.go index 5d21e0b3a8..65c00a7b00 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -42,7 +42,7 @@ func TestValidateGuestUser(t *testing.T) { defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser("", "", nil) assert.True(t, user != nil) assert.True(t, err == nil) @@ -55,7 +55,7 @@ func TestValidateUser(t *testing.T) { dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser("invalid:name", "", nil) assert.Equal(t, user, (User)(nil)) assert.True(t, err != nil) @@ -74,7 +74,7 @@ func TestValidateRole(t *testing.T) { dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.NewRole("invalid:name", nil) assert.Equal(t, (User)(nil), role) assert.True(t, err != nil) @@ -93,7 +93,7 @@ func TestValidateUserEmail(t *testing.T) { dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) badEmails := []string{"", "foo", "foo@", "@bar", "foo@bar@buzz"} for _, e := range badEmails { assert.False(t, IsValidEmail(e)) @@ -114,7 +114,7 @@ func TestUserPasswords(t *testing.T) { dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("me", "letmein", nil) assert.True(t, user.Authenticate("letmein")) assert.False(t, user.Authenticate("password")) @@ -139,7 +139,7 @@ func TestSerializeUser(t *testing.T) { dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("me", "letmein", ch.BaseSetOf(t, "me", "public")) require.NoError(t, user.SetEmail("foo@example.com")) encoded, _ := base.JSONMarshal(user) @@ -161,7 +161,7 @@ func TestSerializeRole(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, _ := auth.NewRole("froods", ch.BaseSetOf(t, "hoopy", "public")) encoded, _ := base.JSONMarshal(role) assert.True(t, encoded != nil) @@ -180,7 +180,7 @@ func TestUserAccess(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("foo", "password", nil) assert.Equal(t, ch.BaseSetOf(t, "!"), user.expandWildCardChannel(ch.BaseSetOf(t, "*"))) assert.False(t, user.canSeeChannel("x")) @@ -252,7 +252,7 @@ func TestGetMissingUser(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.GetUser("noSuchUser") assert.Equal(t, nil, err) assert.True(t, user == nil) @@ -266,7 +266,7 @@ func TestGetMissingRole(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.GetRole("noSuchRole") assert.Equal(t, nil, err) assert.True(t, role == nil) @@ -276,7 +276,7 @@ func TestGetGuestUser(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.GetUser("") require.Equal(t, nil, err) assert.Equal(t, auth.defaultGuestUser(), user) @@ -287,7 +287,7 @@ func TestSaveUsers(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("testUser", "password", ch.BaseSetOf(t, "test")) err := auth.Save(user) assert.NoError(t, err) @@ -302,7 +302,7 @@ func TestSaveRoles(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, _ := auth.NewRole("testRole", ch.BaseSetOf(t, "test")) err := auth.Save(role) assert.Equal(t, nil, err) @@ -374,7 +374,7 @@ func TestRebuildUserChannels(t *testing.T) { dataStore := bucket.GetSingleDataStore() computer := mockComputer{} computer.AddChannelsForCollection(base.DefaultScope, base.DefaultCollection, ch.AtSequence(ch.BaseSetOf(t, "derived1", "derived2"), 1)) - auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("testUser", "password", ch.BaseSetOf(t, "explicit1")) err := auth.Save(user) assert.NoError(t, err) @@ -395,7 +395,7 @@ func TestRebuildUserChannelsMultiCollection(t *testing.T) { computer.AddChannelsForCollection(base.DefaultScope, base.DefaultCollection, ch.AtSequence(ch.BaseSetOf(t, "derived1", "derived2"), 1)) computer.AddChannelsForCollection("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "derived3", "derived4"), 1)) - options := DefaultAuthenticatorOptions() + options := DefaultAuthenticatorOptions(base.TestCtx(t)) options.Collections = map[string]map[string]struct{}{ base.DefaultScope: {base.DefaultCollection: struct{}{}}, "scope1": {"collection1": struct{}{}}, @@ -422,7 +422,7 @@ func TestRebuildUserChannelsNamedCollection(t *testing.T) { computer := mockComputer{} computer.AddChannelsForCollection("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "derived3", "derived4"), 1)) - options := DefaultAuthenticatorOptions() + options := DefaultAuthenticatorOptions(base.TestCtx(t)) options.Collections = map[string]map[string]struct{}{ "scope1": {"collection1": struct{}{}}, } @@ -454,7 +454,7 @@ func TestRebuildRoleChannels(t *testing.T) { dataStore := bucket.GetSingleDataStore() computer := mockComputer{} computer.AddRoleChannelsForCollection(base.DefaultScope, base.DefaultCollection, ch.AtSequence(ch.BaseSetOf(t, "derived1", "derived2"), 1)) - auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.NewRole("testRole", ch.BaseSetOf(t, "explicit1")) assert.NoError(t, err) err = auth.Save(role) @@ -474,7 +474,7 @@ func TestRebuildChannelsError(t *testing.T) { defer bucket.Close() dataStore := bucket.GetSingleDataStore() computer := mockComputer{} - auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.NewRole("testRole2", ch.BaseSetOf(t, "explicit1")) assert.NoError(t, err) err = auth.Save(role) @@ -495,7 +495,7 @@ func TestRebuildUserRoles(t *testing.T) { defer bucket.Close() dataStore := bucket.GetSingleDataStore() computer := mockComputer{roles: ch.AtSequence(base.SetOf("role1", "role2"), 3)} - auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &computer, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("testUser", "letmein", nil) user.SetExplicitRoles(ch.TimedSet{"role3": ch.NewVbSimpleSequence(1), "role1": ch.NewVbSimpleSequence(1)}, 1) err := auth.Save(user) @@ -524,7 +524,7 @@ func TestRoleInheritance(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, _ := auth.NewRole("square", ch.BaseSetOf(t, "dull", "duller", "dullest")) assert.Equal(t, nil, auth.Save(role)) role, _ = auth.NewRole("frood", ch.BaseSetOf(t, "hoopy", "hoopier", "hoopiest")) @@ -554,7 +554,7 @@ func TestRegisterUser(t *testing.T) { dataStore := bucket.GetSingleDataStore() // Register user based on name, email - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.RegisterNewUser("ValidName", "foo@example.com") require.NoError(t, err) assert.Equal(t, "ValidName", user.Name()) @@ -620,7 +620,7 @@ func TestCASUpdatePrincipal(t *testing.T) { email := "foo@bar.org" // Create user - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) // Modify the bcrypt cost to test rehashPassword properly below require.Error(t, auth.SetBcryptCost(5)) @@ -690,7 +690,7 @@ func TestConcurrentUserWrites(t *testing.T) { email := "foo@bar.org" // Create user - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) // Modify the bcrypt cost to test rehashPassword properly below require.Error(t, auth.SetBcryptCost(5)) @@ -778,7 +778,7 @@ func TestAuthenticateTrustedJWT(t *testing.T) { dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) ctx := base.TestCtx(t) @@ -1196,7 +1196,7 @@ func TestGetPrincipal(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) const ( channelRead = "read" @@ -1268,7 +1268,7 @@ func TestAuthenticateUntrustedJWT(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) issuerFacebookAccounts := "https://accounts.facebook.com" issuerAmazonAccounts := "https://accounts.amazon.com" @@ -1528,7 +1528,7 @@ func TestRevocationScenario1(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -1622,7 +1622,7 @@ func TestRevocationScenario2(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -1722,7 +1722,7 @@ func TestRevocationScenario3(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -1831,7 +1831,7 @@ func TestRevocationScenario4(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -1927,7 +1927,7 @@ func TestRevocationScenario5(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2007,7 +2007,7 @@ func TestRevocationScenario6(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2091,7 +2091,7 @@ func TestRevocationScenario7(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2172,7 +2172,7 @@ func TestRevocationScenario8(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2234,7 +2234,7 @@ func TestRevocationScenario9(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2293,7 +2293,7 @@ func TestRevocationScenario10(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2355,7 +2355,7 @@ func TestRevocationScenario11(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2423,7 +2423,7 @@ func TestRevocationScenario12(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2485,7 +2485,7 @@ func TestRevocationScenario13(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2545,7 +2545,7 @@ func TestRevocationScenario14(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) initializeScenario(t, auth) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 5) @@ -2582,7 +2582,7 @@ func TestRoleSoftDelete(t *testing.T) { dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) const roleName = "role" @@ -2696,7 +2696,7 @@ func TestObtainChannelsForDeletedRole(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) const roleName = "role" @@ -2734,7 +2734,7 @@ func TestInvalidateRoles(t *testing.T) { leakyDataStore, ok := base.AsLeakyDataStore(leakyBucket.DefaultDataStore()) require.True(t, ok) - auth := NewAuthenticator(leakyDataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(leakyDataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) // Invalidate role on non-existent user and ensure no error err := auth.InvalidateRoles("user", 0) @@ -2807,7 +2807,7 @@ func TestInvalidateChannels(t *testing.T) { leakyDataStore, ok := base.AsLeakyDataStore(leakyBucket.DefaultDataStore()) require.True(t, ok) - auth := NewAuthenticator(leakyDataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(leakyDataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) // Invalidate channels on non-existent user / role and ensure no error err := auth.InvalidateDefaultChannels(testCase.name, testCase.isUser, 0) diff --git a/auth/auth_time_sensitive_test.go b/auth/auth_time_sensitive_test.go index 788e55109d..d04000dea5 100644 --- a/auth/auth_time_sensitive_test.go +++ b/auth/auth_time_sensitive_test.go @@ -31,7 +31,7 @@ func TestAuthenticationSpeed(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("me", "goIsKewl", nil) assert.True(t, user.Authenticate("goIsKewl")) diff --git a/auth/collection_access_test.go b/auth/collection_access_test.go index 82010bd2e8..85fe464765 100644 --- a/auth/collection_access_test.go +++ b/auth/collection_access_test.go @@ -33,7 +33,7 @@ func TestUserCollectionAccess(t *testing.T) { // User with no access: bucket := base.GetTestBucket(t) defer bucket.Close() - options := DefaultAuthenticatorOptions() + options := DefaultAuthenticatorOptions(base.TestCtx(t)) options.Collections = map[string]map[string]struct{}{ "scope1": { "collection1": struct{}{}, @@ -170,7 +170,7 @@ func TestSerializeUserWithCollections(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() - auth := NewAuthenticator(bucket.GetSingleDataStore(), nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(bucket.GetSingleDataStore(), nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, _ := auth.NewUser("me", "letmein", ch.BaseSetOf(t, "me", "public")) encoded, err := base.JSONMarshal(user) require.NoError(t, err) diff --git a/auth/jwt.go b/auth/jwt.go index da8c610d40..f03a6004dc 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -59,13 +59,13 @@ type JWTConfigCommon struct { } // ValidFor returns whether the issuer matches, and one of the audiences matches -func (j JWTConfigCommon) ValidFor(issuer string, audiences audience) bool { +func (j JWTConfigCommon) ValidFor(ctx context.Context, issuer string, audiences audience) bool { if j.Issuer != issuer { return false } // Nil ClientID is invalid (checked by config validation), but empty-string disables audience checking if j.ClientID == nil { - base.ErrorfCtx(context.Background(), "JWTConfigCommon.ClientID nil - should never happen (for issuer %v)", base.UD(j.Issuer)) + base.ErrorfCtx(ctx, "JWTConfigCommon.ClientID nil - should never happen (for issuer %v)", base.UD(j.Issuer)) return false } if *j.ClientID == "" { @@ -147,7 +147,7 @@ type LocalJWTAuthConfig struct { } // BuildProvider prepares a LocalJWTAuthProvider from this config, initialising keySet. -func (l LocalJWTAuthConfig) BuildProvider(name string) *LocalJWTAuthProvider { +func (l LocalJWTAuthConfig) BuildProvider(ctx context.Context, name string) *LocalJWTAuthProvider { var prov *LocalJWTAuthProvider // validation ensures these are truly mutually exclusive if len(l.Keys) > 0 { @@ -160,10 +160,10 @@ func (l LocalJWTAuthConfig) BuildProvider(name string) *LocalJWTAuthProvider { prov = &LocalJWTAuthProvider{ LocalJWTAuthConfig: l, name: name, - keySet: oidc.NewRemoteKeySet(context.Background(), l.JWKSURI), + keySet: oidc.NewRemoteKeySet(ctx, l.JWKSURI), } } - prov.initUserPrefix() + prov.initUserPrefix(ctx) return prov } @@ -209,14 +209,14 @@ func (l *LocalJWTAuthProvider) common() JWTConfigCommon { return l.JWTConfigCommon } -func (l *LocalJWTAuthProvider) initUserPrefix() { +func (l *LocalJWTAuthProvider) initUserPrefix(ctx context.Context) { if l.UserPrefix != "" || l.UsernameClaim != "" { return } issuerURL, err := url.ParseRequestURI(l.Issuer) if err != nil { - base.WarnfCtx(context.TODO(), "Unable to parse issuer URI when initializing user prefix - using provider name") + base.WarnfCtx(ctx, "Unable to parse issuer URI when initializing user prefix - using provider name") l.UserPrefix = l.name return } diff --git a/auth/jwt_test.go b/auth/jwt_test.go index c11b5f7759..894b8fb6eb 100644 --- a/auth/jwt_test.go +++ b/auth/jwt_test.go @@ -86,7 +86,7 @@ func TestJWTVerifyToken(t *testing.T) { testIssuer = "testIssuer" testClientID = "testAud" ) - + ctx := base.TestCtx(t) common := JWTConfigCommon{ Issuer: testIssuer, ClientID: base.StringPtr(testClientID), @@ -96,13 +96,13 @@ func TestJWTVerifyToken(t *testing.T) { Algorithms: []string{"RS256", "ES256"}, Keys: []jose.JSONWebKey{testRSAJWK, testECJWK, testEncRSAJWK}, SkipExpiryCheck: base.BoolPtr(true), - }.BuildProvider("test") + }.BuildProvider(ctx, "test") providerWithExpiryCheck := LocalJWTAuthConfig{ JWTConfigCommon: common, Algorithms: []string{"RS256", "ES256"}, Keys: []jose.JSONWebKey{testRSAJWK, testECJWK, testEncRSAJWK}, SkipExpiryCheck: base.BoolPtr(false), - }.BuildProvider("test") + }.BuildProvider(ctx, "test") t.Run("garbage", test(baseProvider, "INVALID", anyError)) diff --git a/auth/main_test.go b/auth/main_test.go index 9928d8b369..df24c7d4e1 100644 --- a/auth/main_test.go +++ b/auth/main_test.go @@ -11,12 +11,14 @@ licenses/APL2.txt. package auth import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - base.TestBucketPoolNoIndexes(m, tbpOptions) + base.TestBucketPoolNoIndexes(ctx, m, tbpOptions) } diff --git a/auth/oidc.go b/auth/oidc.go index 083c595c88..dc6f35bd2f 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -540,7 +540,7 @@ func (op *OIDCProvider) runDiscoverySync(ctx context.Context, discoveryURL strin return ttl, err } if refresh && !op.isStandardDiscovery() { - verifier := op.generateVerifier(&metadata, context.Background()) + verifier := op.generateVerifier(&metadata, ctx) op.client.SetConfig(verifier, metadata.endpoint()) op.metadata = metadata } @@ -624,7 +624,7 @@ func (op *OIDCProvider) verifyToken(ctx context.Context, token string, callbackU } // Verify claims and signature on the JWT; ensure that it's been signed by the provider. - idToken, err := client.verifyJWT(token) + idToken, err := client.verifyJWT(ctx, token) if err != nil { base.DebugfCtx(ctx, base.KeyAuth, "Client %v could not verify JWT. Error: %v", base.UD(client), err) return nil, err @@ -661,10 +661,10 @@ func getIssuerWithAudience(token *jwt.JSONWebToken) (issuer string, audiences [] // verifyJWT parses a raw ID Token, verifies it's been signed by the provider // and returns the payload. It uses the ID Token Verifier to verify the token. -func (client *OIDCClient) verifyJWT(token string) (*oidc.IDToken, error) { +func (client *OIDCClient) verifyJWT(ctx context.Context, token string) (*oidc.IDToken, error) { client.mutex.RLock() defer client.mutex.RUnlock() - return client.verifier.Verify(context.Background(), token) + return client.verifier.Verify(ctx, token) } func SetURLQueryParam(strURL, name, value string) (string, error) { diff --git a/auth/oidc_test.go b/auth/oidc_test.go index 45c251d238..505374aef2 100644 --- a/auth/oidc_test.go +++ b/auth/oidc_test.go @@ -1191,7 +1191,7 @@ func TestJWTRolesChannels(t *testing.T) { roleChannels: map[string]ch.TimedSet{}, } - auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, &testMockComputer, DefaultAuthenticatorOptions(base.TestCtx(t))) provider := &OIDCProvider{ Name: "foo", diff --git a/auth/password_hash_test.go b/auth/password_hash_test.go index 5cd03de617..c3386c5a0e 100644 --- a/auth/password_hash_test.go +++ b/auth/password_hash_test.go @@ -62,7 +62,7 @@ func TestSetBcryptCost(t *testing.T) { bucket := base.GetTestBucket(t) defer bucket.Close() dataStore := bucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) err := auth.SetBcryptCost(DefaultBcryptCost - 1) // below minimum allowed value assert.Equal(t, ErrInvalidBcryptCost, errors.Cause(err)) diff --git a/auth/role_test.go b/auth/role_test.go index 270360d469..555a917632 100644 --- a/auth/role_test.go +++ b/auth/role_test.go @@ -45,7 +45,7 @@ func TestAuthorizeChannelsRole(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.NewRole("root", channels.BaseSetOf(t, "superuser")) assert.NoError(t, err) @@ -65,9 +65,9 @@ func TestRoleKeysHash(t *testing.T) { defer testBucket.Close() dataStore := testBucket.DefaultDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) if !metadataDefault { - namedMetadataOptions := DefaultAuthenticatorOptions() + namedMetadataOptions := DefaultAuthenticatorOptions(base.TestCtx(t)) namedMetadataOptions.MetaKeys = base.NewMetadataKeys("foo") auth = NewAuthenticator(dataStore, nil, namedMetadataOptions) diff --git a/auth/session_test.go b/auth/session_test.go index a08ca35bb6..b4c54f9f17 100644 --- a/auth/session_test.go +++ b/auth/session_test.go @@ -29,7 +29,7 @@ func TestCreateSession(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(username, "password", base.Set{}) require.NoError(t, err) @@ -74,7 +74,7 @@ func TestDeleteSession(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) id, err := base.GenerateRandomSecret() require.NoError(t, err) @@ -103,7 +103,7 @@ func TestMakeSessionCookie(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) sessionID, err := base.GenerateRandomSecret() require.NoError(t, err) @@ -129,7 +129,7 @@ func TestMakeSessionCookieProperties(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) sessionID, err := base.GenerateRandomSecret() require.NoError(t, err) @@ -164,7 +164,7 @@ func TestDeleteSessionForCookie(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) sessionID, err := base.GenerateRandomSecret() require.NoError(t, err) @@ -227,7 +227,7 @@ func TestCreateSessionChangePassword(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(test.username, test.password, base.Set{}) require.NoError(t, err) @@ -266,7 +266,7 @@ func TestUserWithoutSessionUUID(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) const username = "Alice" user, err := auth.NewUser(username, "password", base.Set{}) require.NoError(t, err) diff --git a/auth/user_test.go b/auth/user_test.go index 579842da57..24b86137b5 100644 --- a/auth/user_test.go +++ b/auth/user_test.go @@ -35,7 +35,7 @@ func TestUserAuthenticateDisabled(t *testing.T) { defer bucket.Close() dataStore := bucket.GetSingleDataStore() // Create user - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) u, err := auth.NewUser(username, oldPassword, base.Set{}) assert.NoError(t, err) assert.NotNil(t, u) @@ -70,7 +70,7 @@ func TestUserAuthenticatePasswordHashUpgrade(t *testing.T) { defer bucket.Close() dataStore := bucket.GetSingleDataStore() // Create user - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) u, err := auth.NewUser(username, oldPassword, base.Set{}) require.NoError(t, err) require.NotNil(t, u) @@ -252,7 +252,7 @@ func TestCanSeeChannelSince(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) freeChannels := base.SetFromArray([]string{"ESPN", "HBO", "FX", "AMC"}) user, err := auth.NewUser("user", "password", freeChannels) assert.Nil(t, err) @@ -280,7 +280,7 @@ func TestGetAddedChannels(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) role, err := auth.NewRole("music", channels.BaseSetOf(t, "Spotify", "Youtube")) assert.Nil(t, err) @@ -323,7 +323,7 @@ func TestUserAuthenticateWithDisabledUserAccount(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(username, password, base.Set{}) assert.NoError(t, err) @@ -345,7 +345,7 @@ func TestUserAuthenticateWithOldPasswordHash(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(username, password, base.Set{}) assert.NoError(t, err) @@ -366,7 +366,7 @@ func TestUserAuthenticateWithBadPasswordHash(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(username, password, base.Set{}) assert.NoError(t, err) @@ -387,7 +387,7 @@ func TestUserAuthenticateWithNoHashAndBadPassword(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) user, err := auth.NewUser(username, password, base.Set{}) assert.NoError(t, err) @@ -404,9 +404,9 @@ func TestUserKeysHash(t *testing.T) { defer testBucket.Close() dataStore := testBucket.GetSingleDataStore() - auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions()) + auth := NewAuthenticator(dataStore, nil, DefaultAuthenticatorOptions(base.TestCtx(t))) if !metadataDefault { - namedMetadataOptions := DefaultAuthenticatorOptions() + namedMetadataOptions := DefaultAuthenticatorOptions(base.TestCtx(t)) namedMetadataOptions.MetaKeys = base.NewMetadataKeys("foo") auth = NewAuthenticator(testBucket.GetSingleDataStore(), nil, namedMetadataOptions) diff --git a/base/bootstrap.go b/base/bootstrap.go index 0198a16ea5..7da3a978fe 100644 --- a/base/bootstrap.go +++ b/base/bootstrap.go @@ -29,20 +29,20 @@ type BootstrapConnection interface { // GetConfigBuckets returns a list of bucket names where a bootstrap metadata documents could reside. GetConfigBuckets() ([]string, error) // GetMetadataDocument fetches a bootstrap metadata document for a given bucket and key, along with the CAS of the config document. - GetMetadataDocument(bucket, key string, valuePtr interface{}) (cas uint64, err error) + GetMetadataDocument(ctx context.Context, bucket, key string, valuePtr interface{}) (cas uint64, err error) // InsertMetadataDocument saves a new bootstrap metadata document for a given bucket and key. - InsertMetadataDocument(bucket, key string, value interface{}) (newCAS uint64, err error) + InsertMetadataDocument(ctx context.Context, bucket, key string, value interface{}) (newCAS uint64, err error) // DeleteMetadataDocument deletes an existing bootstrap metadata document for a given bucket and key. - DeleteMetadataDocument(bucket, key string, cas uint64) (err error) + DeleteMetadataDocument(ctx context.Context, bucket, key string, cas uint64) (err error) // UpdateMetadataDocument updates an existing bootstrap metadata document for a given bucket and key. updateCallback can return nil to remove the config. Retries on CAS failure. - UpdateMetadataDocument(bucket, key string, updateCallback func(rawBucketConfig []byte, rawBucketConfigCas uint64) (updatedConfig []byte, err error)) (newCAS uint64, err error) + UpdateMetadataDocument(ctx context.Context, bucket, key string, updateCallback func(rawBucketConfig []byte, rawBucketConfigCas uint64) (updatedConfig []byte, err error)) (newCAS uint64, err error) // WriteMetadataDocument writes a bootstrap metadata document for a given bucket and key. Does not retry on CAS failure. - WriteMetadataDocument(bucket, key string, cas uint64, valuePtr interface{}) (casOut uint64, err error) + WriteMetadataDocument(ctx context.Context, bucket, key string, cas uint64, valuePtr interface{}) (casOut uint64, err error) // TouchMetadataDocument sets the specified property in a bootstrap metadata document for a given bucket and key. Used to // trigger CAS update on the document, to block any racing updates. Does not retry on CAS failure. - TouchMetadataDocument(bucket, key string, property string, value string, cas uint64) (casOut uint64, err error) + TouchMetadataDocument(ctx context.Context, bucket, key string, property string, value string, cas uint64) (casOut uint64, err error) // KeyExists checks whether the specified key exists - KeyExists(bucket, key string) (exists bool, err error) + KeyExists(ctx context.Context, bucket, key string) (exists bool, err error) // Returns the bootstrap connection's cluster connection as N1QLStore for the specified bucket/scope/collection. // Does NOT establish a bucket connection, the bucketName/scopeName/collectionName is for query scoping only GetClusterN1QLStore(bucketName, scopeName, collectionName string) (*ClusterOnlyN1QLStore, error) @@ -144,12 +144,12 @@ func (c *cachedBucketConnections) _set(bucketName string, bucket *cachedBucket) var _ BootstrapConnection = &CouchbaseCluster{} // NewCouchbaseCluster creates and opens a Couchbase Server cluster connection. -func NewCouchbaseCluster(server, username, password, +func NewCouchbaseCluster(ctx context.Context, server, username, password, x509CertPath, x509KeyPath, caCertPath string, forcePerBucketAuth bool, perBucketCreds PerBucketCredentialsConfig, tlsSkipVerify *bool, useXattrConfig *bool, bucketMode BucketConnectionMode) (*CouchbaseCluster, error) { - securityConfig, err := GoCBv2SecurityConfig(tlsSkipVerify, caCertPath) + securityConfig, err := GoCBv2SecurityConfig(ctx, tlsSkipVerify, caCertPath) if err != nil { return nil, err } @@ -311,12 +311,12 @@ func (cc *CouchbaseCluster) GetConfigBuckets() ([]string, error) { return bucketList, nil } -func (cc *CouchbaseCluster) GetMetadataDocument(location, docID string, valuePtr interface{}) (cas uint64, err error) { +func (cc *CouchbaseCluster) GetMetadataDocument(ctx context.Context, location, docID string, valuePtr interface{}) (cas uint64, err error) { if cc == nil { return 0, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return 0, err @@ -324,15 +324,15 @@ func (cc *CouchbaseCluster) GetMetadataDocument(location, docID string, valuePtr defer teardown() - return cc.configPersistence.loadConfig(b.DefaultCollection(), docID, valuePtr) + return cc.configPersistence.loadConfig(ctx, b.DefaultCollection(), docID, valuePtr) } -func (cc *CouchbaseCluster) InsertMetadataDocument(location, key string, value interface{}) (newCAS uint64, err error) { +func (cc *CouchbaseCluster) InsertMetadataDocument(ctx context.Context, location, key string, value interface{}) (newCAS uint64, err error) { if cc == nil { return 0, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return 0, err } @@ -342,12 +342,12 @@ func (cc *CouchbaseCluster) InsertMetadataDocument(location, key string, value i } // WriteMetadataDocument writes a metadata document, and fails on CAS mismatch -func (cc *CouchbaseCluster) WriteMetadataDocument(location, docID string, cas uint64, value interface{}) (newCAS uint64, err error) { +func (cc *CouchbaseCluster) WriteMetadataDocument(ctx context.Context, location, docID string, cas uint64, value interface{}) (newCAS uint64, err error) { if cc == nil { return 0, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return 0, err } @@ -362,13 +362,13 @@ func (cc *CouchbaseCluster) WriteMetadataDocument(location, docID string, cas ui return uint64(casOut), err } -func (cc *CouchbaseCluster) TouchMetadataDocument(location, docID string, property, value string, cas uint64) (newCAS uint64, err error) { +func (cc *CouchbaseCluster) TouchMetadataDocument(ctx context.Context, location, docID string, property, value string, cas uint64) (newCAS uint64, err error) { if cc == nil { return 0, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return 0, err } @@ -379,12 +379,12 @@ func (cc *CouchbaseCluster) TouchMetadataDocument(location, docID string, proper } -func (cc *CouchbaseCluster) DeleteMetadataDocument(location, key string, cas uint64) (err error) { +func (cc *CouchbaseCluster) DeleteMetadataDocument(ctx context.Context, location, key string, cas uint64) (err error) { if cc == nil { return errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return err } @@ -395,12 +395,12 @@ func (cc *CouchbaseCluster) DeleteMetadataDocument(location, key string, cas uin } // UpdateMetadataDocument retries on CAS mismatch -func (cc *CouchbaseCluster) UpdateMetadataDocument(location, docID string, updateCallback func(bucketConfig []byte, rawBucketConfigCas uint64) (newConfig []byte, err error)) (newCAS uint64, err error) { +func (cc *CouchbaseCluster) UpdateMetadataDocument(ctx context.Context, location, docID string, updateCallback func(bucketConfig []byte, rawBucketConfigCas uint64) (newConfig []byte, err error)) (newCAS uint64, err error) { if cc == nil { return 0, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return 0, err } @@ -409,7 +409,7 @@ func (cc *CouchbaseCluster) UpdateMetadataDocument(location, docID string, updat collection := b.DefaultCollection() for { - bucketValue, cas, err := cc.configPersistence.loadRawConfig(collection, docID) + bucketValue, cas, err := cc.configPersistence.loadRawConfig(ctx, collection, docID) if err != nil { return 0, err } @@ -445,12 +445,12 @@ func (cc *CouchbaseCluster) UpdateMetadataDocument(location, docID string, updat } -func (cc *CouchbaseCluster) KeyExists(location, docID string) (exists bool, err error) { +func (cc *CouchbaseCluster) KeyExists(ctx context.Context, location, docID string) (exists bool, err error) { if cc == nil { return false, errors.New("nil CouchbaseCluster") } - b, teardown, err := cc.getBucket(location) + b, teardown, err := cc.getBucket(ctx, location) if err != nil { return false, err @@ -484,10 +484,10 @@ func (cc *CouchbaseCluster) GetClusterN1QLStore(bucketName, scopeName, collectio return NewClusterOnlyN1QLStore(gocbCluster, bucketName, scopeName, collectionName) } -func (cc *CouchbaseCluster) getBucket(bucketName string) (b *gocb.Bucket, teardownFn func(), err error) { +func (cc *CouchbaseCluster) getBucket(ctx context.Context, bucketName string) (b *gocb.Bucket, teardownFn func(), err error) { if cc.bucketConnectionMode != CachedClusterConnections { - return cc.connectToBucket(bucketName) + return cc.connectToBucket(ctx, bucketName) } teardownFn = func() { @@ -501,7 +501,7 @@ func (cc *CouchbaseCluster) getBucket(bucketName string) (b *gocb.Bucket, teardo } // cached bucket not found, connect and add - newBucket, bucketCloseFn, err := cc.connectToBucket(bucketName) + newBucket, bucketCloseFn, err := cc.connectToBucket(ctx, bucketName) if err != nil { return nil, nil, err } @@ -515,7 +515,7 @@ func (cc *CouchbaseCluster) getBucket(bucketName string) (b *gocb.Bucket, teardo } // connectToBucket establishes a new connection to a bucket, and returns the bucket after waiting for it to be ready. -func (cc *CouchbaseCluster) connectToBucket(bucketName string) (b *gocb.Bucket, teardownFn func(), err error) { +func (cc *CouchbaseCluster) connectToBucket(ctx context.Context, bucketName string) (b *gocb.Bucket, teardownFn func(), err error) { var connection *gocb.Cluster if bucketAuth, set := cc.perBucketAuth[bucketName]; set { connection, err = cc.connect(bucketAuth) @@ -548,7 +548,7 @@ func (cc *CouchbaseCluster) connectToBucket(bucketName string) (b *gocb.Bucket, teardownFn = func() { err := connection.Close(&gocb.ClusterCloseOptions{}) if err != nil { - WarnfCtx(context.Background(), "Failed to close cluster connection: %v", err) + WarnfCtx(ctx, "Failed to close cluster connection: %v", err) } } diff --git a/base/bootstrap_test.go b/base/bootstrap_test.go index 368b224501..10761a9344 100644 --- a/base/bootstrap_test.go +++ b/base/bootstrap_test.go @@ -48,8 +48,8 @@ func TestBootstrapRefCounting(t *testing.T) { forcePerBucketAuth := false tlsSkipVerify := BoolPtr(false) var perBucketCredentialsConfig map[string]*CredentialsConfig - - cluster, err := NewCouchbaseCluster(UnitTestUrl(), TestClusterUsername(), TestClusterPassword(), x509CertPath, x509KeyPath, caCertPath, forcePerBucketAuth, perBucketCredentialsConfig, tlsSkipVerify, BoolPtr(TestUseXattrs()), CachedClusterConnections) + ctx := TestCtx(t) + cluster, err := NewCouchbaseCluster(ctx, UnitTestUrl(), TestClusterUsername(), TestClusterPassword(), x509CertPath, x509KeyPath, caCertPath, forcePerBucketAuth, perBucketCredentialsConfig, tlsSkipVerify, BoolPtr(TestUseXattrs()), CachedClusterConnections) require.NoError(t, err) defer cluster.Close() require.NotNil(t, cluster) @@ -67,14 +67,14 @@ func TestBootstrapRefCounting(t *testing.T) { } } - require.Len(t, testBuckets, tbpNumBuckets()) + require.Len(t, testBuckets, tbpNumBuckets(ctx)) // GetConfigBuckets doesn't cache connections, it uses cluster connection to determine number of buckets require.Len(t, cluster.cachedBucketConnections.buckets, 0) primeBucketConnectionCache := func(bucketNames []string) { // Bucket CRUD ops do cache connections for _, bucketName := range bucketNames { - exists, err := cluster.KeyExists(bucketName, "keyThatDoesNotExist") + exists, err := cluster.KeyExists(ctx, bucketName, "keyThatDoesNotExist") require.NoError(t, err) require.False(t, exists) } @@ -102,7 +102,7 @@ func TestBootstrapRefCounting(t *testing.T) { makeConnection := make(chan struct{}) go func() { defer wg.Done() - b, teardown, err := cluster.getBucket(buckets[0]) + b, teardown, err := cluster.getBucket(ctx, buckets[0]) defer teardown() require.NoError(t, err) require.NotNil(t, b) diff --git a/base/bucket.go b/base/bucket.go index 85079dcfdd..9f3deb3372 100644 --- a/base/bucket.go +++ b/base/bucket.go @@ -58,9 +58,9 @@ type WrappingDatastore interface { type CouchbaseBucketStore interface { GetName() string MgmtEps() ([]string, error) - MetadataPurgeInterval() (time.Duration, error) - MaxTTL() (int, error) - HttpClient() *http.Client + MetadataPurgeInterval(ctx context.Context) (time.Duration, error) + MaxTTL(context.Context) (int, error) + HttpClient(context.Context) *http.Client GetSpec() BucketSpec GetMaxVbno() (uint16, error) @@ -69,7 +69,7 @@ type CouchbaseBucketStore interface { GetStatsVbSeqno(maxVbno uint16, useAbsHighSeqNo bool) (uuids map[uint16]uint64, highSeqnos map[uint16]uint64, seqErr error) // mgmtRequest uses the CouchbaseBucketStore's http client to make an http request against a management endpoint. - mgmtRequest(method, uri, contentType string, body io.Reader) (*http.Response, error) + mgmtRequest(ctx context.Context, method, uri, contentType string, body io.Reader) (*http.Response, error) } func AsCouchbaseBucketStore(b Bucket) (CouchbaseBucketStore, bool) { @@ -259,13 +259,13 @@ func (b BucketSpec) GetViewQueryTimeoutMs() uint64 { // TLSConfig creates a TLS configuration and populates the certificates // Errors will get logged then nil is returned. -func (b BucketSpec) TLSConfig() *tls.Config { +func (b BucketSpec) TLSConfig(ctx context.Context) *tls.Config { var certPool *x509.CertPool = nil if !b.TLSSkipVerify { // Add certs if ServerTLSSkipVerify is not set var err error - certPool, err = getRootCAs(b.CACertPath) + certPool, err = getRootCAs(ctx, b.CACertPath) if err != nil { - ErrorfCtx(context.Background(), "Error creating tlsConfig for DCP processing: %v", err) + ErrorfCtx(ctx, "Error creating tlsConfig for DCP processing: %v", err) return nil } } @@ -279,7 +279,7 @@ func (b BucketSpec) TLSConfig() *tls.Config { if b.Certpath != "" && b.Keypath != "" { cert, err := tls.LoadX509KeyPair(b.Certpath, b.Keypath) if err != nil { - ErrorfCtx(context.Background(), "Error creating tlsConfig for DCP processing: %v", err) + ErrorfCtx(ctx, "Error creating tlsConfig for DCP processing: %v", err) return nil } tlsConfig.Certificates = []tls.Certificate{cert} @@ -360,7 +360,7 @@ func GetBucket(ctx context.Context, spec BucketSpec) (bucket Bucket, err error) } InfofCtx(ctx, KeyAll, "Opening Couchbase database %s on <%s> as user %q", MD(spec.BucketName), SD(spec.Server), UD(username)) - bucket, err = GetGoCBv2Bucket(spec) + bucket, err = GetGoCBv2Bucket(ctx, spec) if err != nil { return nil, err } @@ -449,13 +449,13 @@ func GetFeedType(bucket Bucket) (feedType string) { // Gets the bucket max TTL, or 0 if no TTL was set. Sync gateway should fail to bring the DB online if this is non-zero, // since it's not meant to operate against buckets that auto-delete data. -func getMaxTTL(store CouchbaseBucketStore) (int, error) { +func getMaxTTL(ctx context.Context, store CouchbaseBucketStore) (int, error) { var bucketResponseWithMaxTTL struct { MaxTTLSeconds int `json:"maxTTL,omitempty"` } uri := fmt.Sprintf("/pools/default/buckets/%s", store.GetSpec().BucketName) - resp, err := store.mgmtRequest(http.MethodGet, uri, "application/json", nil) + resp, err := store.mgmtRequest(ctx, http.MethodGet, uri, "application/json", nil) if err != nil { return -1, err } @@ -475,8 +475,8 @@ func getMaxTTL(store CouchbaseBucketStore) (int, error) { } // Get the Server UUID of the bucket, this is also known as the Cluster UUID -func GetServerUUID(store CouchbaseBucketStore) (uuid string, err error) { - resp, err := store.mgmtRequest(http.MethodGet, "/pools", "application/json", nil) +func GetServerUUID(ctx context.Context, store CouchbaseBucketStore) (uuid string, err error) { + resp, err := store.mgmtRequest(ctx, http.MethodGet, "/pools", "application/json", nil) if err != nil { return "", err } @@ -501,18 +501,18 @@ func GetServerUUID(store CouchbaseBucketStore) (uuid string, err error) { // Gets the metadata purge interval for the bucket. First checks for a bucket-specific value. If not // found, retrieves the cluster-wide value. -func getMetadataPurgeInterval(store CouchbaseBucketStore) (time.Duration, error) { +func getMetadataPurgeInterval(ctx context.Context, store CouchbaseBucketStore) (time.Duration, error) { // Bucket-specific settings uri := fmt.Sprintf("/pools/default/buckets/%s", store.GetName()) - bucketPurgeInterval, err := retrievePurgeInterval(store, uri) + bucketPurgeInterval, err := retrievePurgeInterval(ctx, store, uri) if bucketPurgeInterval > 0 || err != nil { return bucketPurgeInterval, err } // Cluster-wide settings uri = fmt.Sprintf("/settings/autoCompaction") - clusterPurgeInterval, err := retrievePurgeInterval(store, uri) + clusterPurgeInterval, err := retrievePurgeInterval(ctx, store, uri) if clusterPurgeInterval > 0 || err != nil { return clusterPurgeInterval, err } @@ -524,14 +524,14 @@ func getMetadataPurgeInterval(store CouchbaseBucketStore) (time.Duration, error) // Helper function to retrieve a Metadata Purge Interval from server and convert to hours. Works for any uri // that returns 'purgeInterval' as a root-level property (which includes the two server endpoints for // bucket and server purge intervals). -func retrievePurgeInterval(bucket CouchbaseBucketStore, uri string) (time.Duration, error) { +func retrievePurgeInterval(ctx context.Context, bucket CouchbaseBucketStore, uri string) (time.Duration, error) { // Both of the purge interval endpoints (cluster and bucket) return purgeInterval in the same way var purgeResponse struct { PurgeInterval float64 `json:"purgeInterval,omitempty"` } - resp, err := bucket.mgmtRequest(http.MethodGet, uri, "application/json", nil) + resp, err := bucket.mgmtRequest(ctx, http.MethodGet, uri, "application/json", nil) if err != nil { return 0, err } @@ -539,7 +539,7 @@ func retrievePurgeInterval(bucket CouchbaseBucketStore, uri string) (time.Durati defer func() { _ = resp.Body.Close() }() if resp.StatusCode == http.StatusForbidden { - WarnfCtx(context.TODO(), "403 Forbidden attempting to access %s. Bucket user must have Bucket Full Access and Bucket Admin roles to retrieve metadata purge interval.", UD(uri)) + WarnfCtx(ctx, "403 Forbidden attempting to access %s. Bucket user must have Bucket Full Access and Bucket Admin roles to retrieve metadata purge interval.", UD(uri)) } else if resp.StatusCode != http.StatusOK { return 0, errors.New(resp.Status) } @@ -558,10 +558,10 @@ func retrievePurgeInterval(bucket CouchbaseBucketStore, uri string) (time.Durati return time.Duration(purgeIntervalHours) * time.Hour, nil } -func ensureBodyClosed(body io.ReadCloser) { +func ensureBodyClosed(ctx context.Context, body io.ReadCloser) { err := body.Close() if err != nil { - DebugfCtx(context.TODO(), KeyBucket, "Failed to close socket: %v", err) + DebugfCtx(ctx, KeyBucket, "Failed to close socket: %v", err) } } @@ -580,22 +580,22 @@ func AsSubdocStore(ds DataStore) (sgbucket.SubdocStore, bool) { // WaitUntilDataStoreExists will try to perform an operation in the given DataStore until it can succeed. // // There's no WaitForReady operation in GoCB for collections, only Buckets, so attempting to use Exists in this way this seems like our best option to check for availability. -func WaitUntilDataStoreExists(ds DataStore) error { - return WaitForNoError(func() error { +func WaitUntilDataStoreExists(ctx context.Context, ds DataStore) error { + return WaitForNoError(ctx, func() error { _, err := ds.Exists("WaitUntilDataStoreExists") return err }) } // RequireNoBucketTTL ensures there is no MaxTTL set on the bucket (SG #3314) -func RequireNoBucketTTL(b Bucket) error { +func RequireNoBucketTTL(ctx context.Context, b Bucket) error { cbs, ok := AsCouchbaseBucketStore(b) if !ok { // Not a Couchbase bucket - no TTL check to do return nil } - maxTTL, err := cbs.MaxTTL() + maxTTL, err := cbs.MaxTTL(ctx) if err != nil { return err } diff --git a/base/bucket_gocb.go b/base/bucket_gocb.go index b5313a9ba3..b15402c5cb 100644 --- a/base/bucket_gocb.go +++ b/base/bucket_gocb.go @@ -67,7 +67,7 @@ func isGoCBQueryTimeoutError(err error) bool { } // putDDocForTombstones uses the provided client and endpoints to create a design doc with index_xattr_on_deleted_docs=true -func putDDocForTombstones(name string, payload []byte, capiEps []string, client *http.Client, username string, password string) error { +func putDDocForTombstones(ctx context.Context, name string, payload []byte, capiEps []string, client *http.Client, username string, password string) error { // From gocb.Bucket.getViewEp() - pick view endpoint at random if len(capiEps) == 0 { @@ -92,7 +92,7 @@ func putDDocForTombstones(name string, payload []byte, capiEps []string, client return err } - defer ensureBodyClosed(resp.Body) + defer ensureBodyClosed(ctx, resp.Body) if resp.StatusCode != 201 { data, err := io.ReadAll(resp.Body) if err != nil { @@ -149,20 +149,20 @@ func normalizeIntToUint(value interface{}) (uint, error) { } } -func asBool(value interface{}) bool { +func asBool(ctx context.Context, value interface{}) bool { switch typeValue := value.(type) { case string: parsedVal, err := strconv.ParseBool(typeValue) if err != nil { - WarnfCtx(context.Background(), "asBool called with unknown value: %v. defaulting to false", typeValue) + WarnfCtx(ctx, "asBool called with unknown value: %v. defaulting to false", typeValue) return false } return parsedVal case bool: return typeValue default: - WarnfCtx(context.Background(), "asBool called with unknown type: %T. defaulting to false", typeValue) + WarnfCtx(ctx, "asBool called with unknown type: %T. defaulting to false", typeValue) return false } diff --git a/base/bucket_gocb_test.go b/base/bucket_gocb_test.go index 506ee224a1..85bffb2102 100644 --- a/base/bucket_gocb_test.go +++ b/base/bucket_gocb_test.go @@ -1416,7 +1416,8 @@ func TestDeleteWithXattrWithSimulatedRaceResurrect(t *testing.T) { // case to KvXattrStore to pass to deleteWithXattrInternal collection, ok := dataStore.(*Collection) require.True(t, ok) - deleteErr := deleteWithXattrInternal(collection, key, xattrName, callback) + ctx := TestCtx(t) + deleteErr := deleteWithXattrInternal(ctx, collection, key, xattrName, callback) assert.Equal(t, 1, numTimesCalledBack) assert.True(t, deleteErr != nil, "We expected an error here, because deleteWithXattrInternal should have "+ " detected that the doc was resurrected during its execution") @@ -1823,9 +1824,9 @@ func TestApplyViewQueryOptions(t *testing.T) { ViewQueryParamKey: "hello", ViewQueryParamKeys: []interface{}{"a", "b"}, } - + ctx := TestCtx(t) // Call applyViewQueryOptions (method being tested) which modifies viewQuery according to params - viewOpts, err := createViewOptions(params) + viewOpts, err := createViewOptions(ctx, params) if err != nil { t.Fatalf("Error calling applyViewQueryOptions: %v", err) } @@ -1899,7 +1900,7 @@ func TestApplyViewQueryOptionsWithStrings(t *testing.T) { ViewQueryParamKeys: []string{"a", "b"}, } - _, err := createViewOptions(params) + _, err := createViewOptions(TestCtx(t), params) if err != nil { t.Fatalf("Error calling applyViewQueryOptions: %v", err) } @@ -1916,8 +1917,9 @@ func TestApplyViewQueryStaleOptions(t *testing.T) { ViewQueryParamStale: "false", } + ctx := TestCtx(t) // if it doesn't blow up, test passes - if _, err := createViewOptions(params); err != nil { + if _, err := createViewOptions(ctx, params); err != nil { t.Fatalf("Error calling applyViewQueryOptions: %v", err) } @@ -1925,7 +1927,7 @@ func TestApplyViewQueryStaleOptions(t *testing.T) { ViewQueryParamStale: "ok", } - if _, err := createViewOptions(params); err != nil { + if _, err := createViewOptions(ctx, params); err != nil { t.Fatalf("Error calling applyViewQueryOptions: %v", err) } @@ -1941,7 +1943,7 @@ func TestCouchbaseServerMaxTTL(t *testing.T) { cbStore, ok := AsCouchbaseBucketStore(bucket) require.True(t, ok) - maxTTL, err := cbStore.MaxTTL() + maxTTL, err := cbStore.MaxTTL(TestCtx(t)) assert.NoError(t, err, "Unexpected error") assert.Equal(t, 0, maxTTL) diff --git a/base/bucket_test.go b/base/bucket_test.go index b7004cc4f5..32e0402777 100644 --- a/base/bucket_test.go +++ b/base/bucket_test.go @@ -397,12 +397,13 @@ func TestTLSConfig(t *testing.T) { Keypath: "/var/lib/couchbase/unknown.client.key", CACertPath: "/var/lib/couchbase/unknown.root.ca.pem", } - conf := spec.TLSConfig() + ctx := TestCtx(t) + conf := spec.TLSConfig(ctx) assert.Nil(t, conf) // Simulate valid configuration scenario with fake mocked certificates and keys; spec = BucketSpec{Certpath: clientCertPath, Keypath: clientKeyPath, CACertPath: rootCertPath} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.NotEmpty(t, conf) assert.NotNil(t, conf.RootCAs) require.Len(t, conf.Certificates, 1) @@ -410,7 +411,7 @@ func TestTLSConfig(t *testing.T) { // Check TLSConfig with no CA certificate, and TlsSkipVerify true; InsecureSkipVerify should be true spec = BucketSpec{TLSSkipVerify: true, Certpath: clientCertPath, Keypath: clientKeyPath} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.NotEmpty(t, conf) assert.True(t, conf.InsecureSkipVerify) require.Len(t, conf.Certificates, 1) @@ -418,7 +419,7 @@ func TestTLSConfig(t *testing.T) { // Check TLSConfig with no certificates provided, and TlsSkipVerify true. InsecureSkipVerify should be true and fields should be nil CBG-1518 spec = BucketSpec{TLSSkipVerify: true} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.NotEmpty(t, conf) assert.True(t, conf.InsecureSkipVerify) assert.Nil(t, conf.RootCAs) @@ -426,7 +427,7 @@ func TestTLSConfig(t *testing.T) { // Check TLSConfig with no certs provided. InsecureSkipVerify should always be false. Should be empty config on Windows CBG-1518 spec = BucketSpec{} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.NotEmpty(t, conf) assert.False(t, conf.InsecureSkipVerify) require.NotNil(t, conf.RootCAs) @@ -435,13 +436,13 @@ func TestTLSConfig(t *testing.T) { // Check TLSConfig by providing invalid root CA certificate; provide root certificate key path // instead of root CA certificate. It should throw "can't append certs from PEM" error. spec = BucketSpec{Certpath: clientCertPath, Keypath: clientKeyPath, CACertPath: rootKeyPath} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.Empty(t, conf) // Provide invalid client certificate key along with valid certificate; It should fail while // trying to add key and certificate to config as x509 key pair; spec = BucketSpec{Certpath: clientCertPath, Keypath: rootKeyPath, CACertPath: rootCertPath} - conf = spec.TLSConfig() + conf = spec.TLSConfig(ctx) assert.Empty(t, conf) } diff --git a/base/bucket_view_test.go b/base/bucket_view_test.go index 56452a6f91..ff827ee387 100644 --- a/base/bucket_view_test.go +++ b/base/bucket_view_test.go @@ -78,7 +78,8 @@ func TestView(t *testing.T) { description := fmt.Sprintf("Wait for view readiness") sleeper := CreateSleeperFunc(50, 100) - viewErr, _ := RetryLoop(description, worker, sleeper) + ctx := TestCtx(t) + viewErr, _ := RetryLoop(ctx, description, worker, sleeper) require.NoError(t, viewErr) // stale=false diff --git a/base/collection.go b/base/collection.go index 97e122d3dd..e2b546a6f2 100644 --- a/base/collection.go +++ b/base/collection.go @@ -29,16 +29,15 @@ import ( ) // GetGoCBv2Bucket opens a connection to the Couchbase cluster and returns a *GocbV2Bucket for the specified BucketSpec. -func GetGoCBv2Bucket(spec BucketSpec) (*GocbV2Bucket, error) { +func GetGoCBv2Bucket(ctx context.Context, spec BucketSpec) (*GocbV2Bucket, error) { - logCtx := context.TODO() connString, err := spec.GetGoCBConnString(nil) if err != nil { - WarnfCtx(logCtx, "Unable to parse server value: %s error: %v", SD(spec.Server), err) + WarnfCtx(ctx, "Unable to parse server value: %s error: %v", SD(spec.Server), err) return nil, err } - securityConfig, err := GoCBv2SecurityConfig(&spec.TLSSkipVerify, spec.CACertPath) + securityConfig, err := GoCBv2SecurityConfig(ctx, &spec.TLSSkipVerify, spec.CACertPath) if err != nil { return nil, err } @@ -49,13 +48,13 @@ func GetGoCBv2Bucket(spec BucketSpec) (*GocbV2Bucket, error) { } if _, ok := authenticator.(gocb.CertificateAuthenticator); ok { - InfofCtx(logCtx, KeyAuth, "Using cert authentication for bucket %s on %s", MD(spec.BucketName), MD(spec.Server)) + InfofCtx(ctx, KeyAuth, "Using cert authentication for bucket %s on %s", MD(spec.BucketName), MD(spec.Server)) } else { - InfofCtx(logCtx, KeyAuth, "Using credential authentication for bucket %s on %s", MD(spec.BucketName), MD(spec.Server)) + InfofCtx(ctx, KeyAuth, "Using credential authentication for bucket %s on %s", MD(spec.BucketName), MD(spec.Server)) } timeoutsConfig := GoCBv2TimeoutsConfig(spec.BucketOpTimeout, StdlibDurationPtr(spec.GetViewQueryTimeout())) - InfofCtx(logCtx, KeyAll, "Setting query timeouts for bucket %s to %v", spec.BucketName, timeoutsConfig.QueryTimeout) + InfofCtx(ctx, KeyAll, "Setting query timeouts for bucket %s to %v", spec.BucketName, timeoutsConfig.QueryTimeout) clusterOptions := gocb.ClusterOptions{ Authenticator: authenticator, @@ -70,7 +69,7 @@ func GetGoCBv2Bucket(spec BucketSpec) (*GocbV2Bucket, error) { cluster, err := gocb.Connect(connString, clusterOptions) if err != nil { - InfofCtx(logCtx, KeyAuth, "Unable to connect to cluster: %v", err) + InfofCtx(ctx, KeyAuth, "Unable to connect to cluster: %v", err) return nil, err } @@ -85,11 +84,11 @@ func GetGoCBv2Bucket(spec BucketSpec) (*GocbV2Bucket, error) { if errors.Is(err, gocb.ErrAuthenticationFailure) { return nil, ErrAuthError } - WarnfCtx(context.TODO(), "Error waiting for cluster to be ready: %v", err) + WarnfCtx(ctx, "Error waiting for cluster to be ready: %v", err) return nil, err } - return GetGocbV2BucketFromCluster(cluster, spec, time.Second*30, true) + return GetGocbV2BucketFromCluster(ctx, cluster, spec, time.Second*30, true) } @@ -105,7 +104,7 @@ func getClusterVersion(cluster *gocb.Cluster) (int, int, error) { return clusterCompatMajor, clusterCompatMinor, nil } -func GetGocbV2BucketFromCluster(cluster *gocb.Cluster, spec BucketSpec, waitUntilReady time.Duration, failFast bool) (*GocbV2Bucket, error) { +func GetGocbV2BucketFromCluster(ctx context.Context, cluster *gocb.Cluster, spec BucketSpec, waitUntilReady time.Duration, failFast bool) (*GocbV2Bucket, error) { // Connect to bucket bucket := cluster.Bucket(spec.BucketName) @@ -124,7 +123,7 @@ func GetGocbV2BucketFromCluster(cluster *gocb.Cluster, spec BucketSpec, waitUnti if errors.Is(err, gocb.ErrAuthenticationFailure) { return nil, ErrAuthError } - WarnfCtx(context.TODO(), "Error waiting for bucket to be ready: %v", err) + WarnfCtx(ctx, "Error waiting for bucket to be ready: %v", err) return nil, err } clusterCompatMajor, clusterCompatMinor, err := getClusterVersion(cluster) @@ -154,7 +153,7 @@ func GetGocbV2BucketFromCluster(cluster *gocb.Cluster, spec BucketSpec, waitUnti if maxConcurrentQueryOps > DefaultHttpMaxIdleConnsPerHost*queryNodeCount { maxConcurrentQueryOps = DefaultHttpMaxIdleConnsPerHost * queryNodeCount - InfofCtx(context.TODO(), KeyAll, "Setting max_concurrent_query_ops to %d based on query node count (%d)", maxConcurrentQueryOps, queryNodeCount) + InfofCtx(ctx, KeyAll, "Setting max_concurrent_query_ops to %d based on query node count (%d)", maxConcurrentQueryOps, queryNodeCount) } gocbv2Bucket.queryOps = make(chan struct{}, maxConcurrentQueryOps) @@ -186,8 +185,9 @@ type GocbV2Bucket struct { } var ( - _ sgbucket.BucketStore = &GocbV2Bucket{} - _ CouchbaseBucketStore = &GocbV2Bucket{} + _ sgbucket.BucketStore = &GocbV2Bucket{} + _ CouchbaseBucketStore = &GocbV2Bucket{} + _ sgbucket.DynamicDataStoreBucket = &GocbV2Bucket{} ) func AsGocbV2Bucket(bucket Bucket) (*GocbV2Bucket, error) { @@ -352,14 +352,14 @@ func (b *GocbV2Bucket) Flush(ctx context.Context) error { workerFlush := func() (shouldRetry bool, err error, value interface{}) { if err := bucketManager.FlushBucket(b.GetName(), nil); err != nil { - WarnfCtx(context.TODO(), "Error flushing bucket %s: %v Will retry.", MD(b.GetName()).Redact(), err) + WarnfCtx(ctx, "Error flushing bucket %s: %v Will retry.", MD(b.GetName()).Redact(), err) return true, err, nil } return false, nil, nil } - err, _ := RetryLoop("EmptyTestBucket", workerFlush, CreateDoublingSleeperFunc(12, 10)) + err, _ := RetryLoop(ctx, "EmptyTestBucket", workerFlush, CreateDoublingSleeperFunc(12, 10)) if err != nil { return err } @@ -382,7 +382,7 @@ func (b *GocbV2Bucket) Flush(ctx context.Context) error { } // Kick off retry loop - err, _ = RetryLoop("Wait until bucket has 0 items after flush", worker, CreateMaxDoublingSleeperFunc(25, 100, 10000)) + err, _ = RetryLoop(ctx, "Wait until bucket has 0 items after flush", worker, CreateMaxDoublingSleeperFunc(25, 100, 10000)) if err != nil { return pkgerrors.Wrapf(err, "Error during Wait until bucket %s has 0 items after flush", MD(b.GetName()).Redact()) } @@ -443,18 +443,18 @@ func (b *GocbV2Bucket) QueryEpsCount() (int, error) { // Gets the metadata purge interval for the bucket. First checks for a bucket-specific value. If not // found, retrieves the cluster-wide value. -func (b *GocbV2Bucket) MetadataPurgeInterval() (time.Duration, error) { - return getMetadataPurgeInterval(b) +func (b *GocbV2Bucket) MetadataPurgeInterval(ctx context.Context) (time.Duration, error) { + return getMetadataPurgeInterval(ctx, b) } -func (b *GocbV2Bucket) MaxTTL() (int, error) { - return getMaxTTL(b) +func (b *GocbV2Bucket) MaxTTL(ctx context.Context) (int, error) { + return getMaxTTL(ctx, b) } -func (b *GocbV2Bucket) HttpClient() *http.Client { +func (b *GocbV2Bucket) HttpClient(ctx context.Context) *http.Client { agent, err := b.getGoCBAgent() if err != nil { - WarnfCtx(context.TODO(), "Unable to obtain gocbcore.Agent while retrieving httpClient:%v", err) + WarnfCtx(ctx, "Unable to obtain gocbcore.Agent while retrieving httpClient:%v", err) return nil } return agent.HTTPClient() @@ -465,7 +465,7 @@ func (b *GocbV2Bucket) BucketName() string { return b.GetName() } -func (b *GocbV2Bucket) mgmtRequest(method, uri, contentType string, body io.Reader) (*http.Response, error) { +func (b *GocbV2Bucket) mgmtRequest(ctx context.Context, method, uri, contentType string, body io.Reader) (*http.Response, error) { if contentType == "" && body != nil { // TODO: CBG-1948 panic("Content-type must be specified for non-null body.") @@ -490,7 +490,7 @@ func (b *GocbV2Bucket) mgmtRequest(method, uri, contentType string, body io.Read req.SetBasicAuth(username, password) } - return b.HttpClient().Do(req) + return b.HttpClient(ctx).Do(req) } // This prevents Sync Gateway from overflowing gocb's pipeline @@ -601,6 +601,7 @@ func (b *GocbV2Bucket) DropDataStore(name sgbucket.DataStoreName) error { } func (b *GocbV2Bucket) CreateDataStore(name sgbucket.DataStoreName) error { + ctx := context.TODO() // fix in sg-bucket // create scope first (if it doesn't already exist) if name.ScopeName() != DefaultScope { err := b.bucket.Collections().CreateScope(name.ScopeName(), nil) @@ -614,7 +615,7 @@ func (b *GocbV2Bucket) CreateDataStore(name sgbucket.DataStoreName) error { } // Can't use Collection.Exists since we can't get a collection until the collection exists on CBS gocbCollection := b.bucket.Scope(name.ScopeName()).Collection(name.CollectionName()) - return WaitForNoError(func() error { + return WaitForNoError(ctx, func() error { _, err := gocbCollection.Exists("fakedocid", nil) return err }) diff --git a/base/collection_gocb.go b/base/collection_gocb.go index 7c4c1693f8..fa4f574efe 100644 --- a/base/collection_gocb.go +++ b/base/collection_gocb.go @@ -40,8 +40,9 @@ type Collection struct { // Ensure that Collection implements sgbucket.DataStore/N1QLStore var ( - _ DataStore = &Collection{} - _ N1QLStore = &Collection{} + _ DataStore = &Collection{} + _ N1QLStore = &Collection{} + _ sgbucket.ViewStore = &Collection{} ) func AsCollection(dataStore DataStore) (*Collection, error) { @@ -412,7 +413,8 @@ func (c *Collection) isRecoverableWriteError(err error) bool { func (c *Collection) GetExpiry(k string) (expiry uint32, getMetaError error) { agent, err := c.Bucket.getGoCBAgent() if err != nil { - WarnfCtx(context.TODO(), "Unable to obtain gocbcore.Agent while retrieving expiry:%v", err) + ctx := context.TODO() // fix in sg-bucket + WarnfCtx(ctx, "Unable to obtain gocbcore.Agent while retrieving expiry:%v", err) return 0, err } diff --git a/base/collection_n1ql_common.go b/base/collection_n1ql_common.go index f427d3a8bd..3bd26923ce 100644 --- a/base/collection_n1ql_common.go +++ b/base/collection_n1ql_common.go @@ -189,7 +189,7 @@ func waitForIndexExistence(ctx context.Context, store N1QLStore, indexName strin } // Kick off retry loop - err, _ := RetryLoop("waitForIndexExistence", worker, CreateMaxDoublingSleeperFunc(25, 100, 15000)) + err, _ := RetryLoop(ctx, "waitForIndexExistence", worker, CreateMaxDoublingSleeperFunc(25, 100, 15000)) if err != nil { return pkgerrors.Wrapf(err, "Error during waitForIndexExistence for index %s", indexName) } @@ -262,7 +262,7 @@ func buildIndexes(ctx context.Context, s N1QLStore, indexNames []string) error { // If indexer reports build will be completed in the background, wait to validate build actually happens. if IsIndexerRetryBuildError(err) { - InfofCtx(context.TODO(), KeyQuery, "Indexer error creating index - waiting for background build. Error:%v", err) + InfofCtx(ctx, KeyQuery, "Indexer error creating index - waiting for background build. Error:%v", err) // Wait for bucket to be created in background before returning return s.WaitForIndexesOnline(ctx, indexNames, false) } @@ -302,7 +302,7 @@ func GetIndexMeta(ctx context.Context, store N1QLStore, indexName string) (exist } // Kick off retry loop - err, val := RetryLoop("GetIndexMeta", worker, CreateMaxDoublingSleeperFunc(25, 100, 15000)) + err, val := RetryLoop(ctx, "GetIndexMeta", worker, CreateMaxDoublingSleeperFunc(25, 100, 15000)) if err != nil { return false, nil, pkgerrors.Wrapf(err, "Error during GetIndexMeta for index %s", indexName) } @@ -495,9 +495,10 @@ func (i *gocbRawIterator) Next(valuePtr interface{}) bool { return false } + ctx := context.TODO() // fix in sg-bucket err := JSONUnmarshal(nextBytes, &valuePtr) if err != nil { - WarnfCtx(context.TODO(), "Unable to marshal view result row into value: %v", err) + WarnfCtx(ctx, "Unable to marshal view result row into value: %v", err) return false } return true diff --git a/base/collection_view.go b/base/collection_view.go index 9f8c86d1e0..03ff1ad24f 100644 --- a/base/collection_view.go +++ b/base/collection_view.go @@ -80,6 +80,7 @@ func (c *Collection) GetDDocs() (ddocs map[string]sgbucket.DesignDoc, err error) } func (c *Collection) PutDDoc(docname string, sgDesignDoc *sgbucket.DesignDoc) error { + ctx := context.TODO() // fix in sg-bucket if !c.IsDefaultScopeCollection() { return fmt.Errorf("views not supported for non-default collection") } @@ -100,20 +101,20 @@ func (c *Collection) PutDDoc(docname string, sgDesignDoc *sgbucket.DesignDoc) er // If design doc needs to be tombstone-aware, requires custom creation* if sgDesignDoc.Options != nil && sgDesignDoc.Options.IndexXattrOnTombstones { - return c.Bucket.putDDocForTombstones(&gocbDesignDoc) + return c.Bucket.putDDocForTombstones(ctx, &gocbDesignDoc) } // Retry for all errors (The view service sporadically returns 500 status codes with Erlang errors (for unknown reasons) - E.g: 500 {"error":"case_clause","reason":"false"}) var worker RetryWorker = func() (bool, error, interface{}) { err := manager.UpsertDesignDocument(gocbDesignDoc, gocb.DesignDocumentNamespaceProduction, nil) if err != nil { - WarnfCtx(context.Background(), "Got error from UpsertDesignDocument: %v - Retrying...", err) + WarnfCtx(ctx, "Got error from UpsertDesignDocument: %v - Retrying...", err) return true, err, nil } return false, nil, nil } - err, _ := RetryLoop("PutDDocRetryLoop", worker, CreateSleeperFunc(5, 100)) + err, _ := RetryLoop(ctx, "PutDDocRetryLoop", worker, CreateSleeperFunc(5, 100)) return err } @@ -159,7 +160,7 @@ type NoNameDesignDocument struct { Views map[string]NoNameView `json:"views"` } -func (b *GocbV2Bucket) putDDocForTombstones(ddoc *gocb.DesignDocument) error { +func (b *GocbV2Bucket) putDDocForTombstones(ctx context.Context, ddoc *gocb.DesignDocument) error { username, password, _ := b.Spec.Auth.GetCredentials() agent, err := b.getGoCBAgent() if err != nil { @@ -177,7 +178,7 @@ func (b *GocbV2Bucket) putDDocForTombstones(ddoc *gocb.DesignDocument) error { return err } - return putDDocForTombstones(ddoc.Name, data, agent.CapiEps(), agent.HTTPClient(), username, password) + return putDDocForTombstones(ctx, ddoc.Name, data, agent.CapiEps(), agent.HTTPClient(), username, password) } @@ -190,9 +191,9 @@ func (c *Collection) DeleteDDoc(docname string) error { } func (c *Collection) View(ddoc, name string, params map[string]interface{}) (sgbucket.ViewResult, error) { - + ctx := context.TODO() // fix in sg-bucket var viewResult sgbucket.ViewResult - gocbViewResult, err := c.executeViewQuery(ddoc, name, params) + gocbViewResult, err := c.executeViewQuery(ctx, ddoc, name, params) if err != nil { return viewResult, err } @@ -221,7 +222,7 @@ func (c *Collection) View(ddoc, name string, params map[string]interface{}) (sgb viewMeta, err := unmarshalViewMetadata(gocbViewResult) if err != nil { - WarnfCtx(context.TODO(), "Unable to type get metadata for gocb ViewResult - the total rows count will be missing.") + WarnfCtx(ctx, "Unable to type get metadata for gocb ViewResult - the total rows count will be missing.") } else { viewResult.TotalRows = viewMeta.TotalRows } @@ -248,19 +249,19 @@ func unmarshalViewMetadata(viewResult *gocb.ViewResultRaw) (viewMetadata, error) } func (c *Collection) ViewQuery(ddoc, name string, params map[string]interface{}) (sgbucket.QueryResultIterator, error) { - - gocbViewResult, err := c.executeViewQuery(ddoc, name, params) + ctx := context.TODO() // fix in sg-bucket + gocbViewResult, err := c.executeViewQuery(ctx, ddoc, name, params) if err != nil { return nil, err } return &gocbRawIterator{rawResult: gocbViewResult, concurrentQueryOpLimitChan: c.Bucket.queryOps}, nil } -func (c *Collection) executeViewQuery(ddoc, name string, params map[string]interface{}) (*gocb.ViewResultRaw, error) { +func (c *Collection) executeViewQuery(ctx context.Context, ddoc, name string, params map[string]interface{}) (*gocb.ViewResultRaw, error) { viewResult := sgbucket.ViewResult{} viewResult.Rows = sgbucket.ViewRows{} - viewOpts, optsErr := createViewOptions(params) + viewOpts, optsErr := createViewOptions(ctx, params) if optsErr != nil { return nil, optsErr } @@ -281,37 +282,37 @@ func (c *Collection) executeViewQuery(ddoc, name string, params map[string]inter } // Applies the viewquery options as specified in the params map to the gocb.ViewOptions -func createViewOptions(params map[string]interface{}) (viewOpts *gocb.ViewOptions, err error) { +func createViewOptions(ctx context.Context, params map[string]interface{}) (viewOpts *gocb.ViewOptions, err error) { viewOpts = &gocb.ViewOptions{} for optionName, optionValue := range params { switch optionName { case ViewQueryParamStale: - viewOpts.ScanConsistency = asViewConsistency(optionValue) + viewOpts.ScanConsistency = asViewConsistency(ctx, optionValue) case ViewQueryParamReduce: - viewOpts.Reduce = asBool(optionValue) + viewOpts.Reduce = asBool(ctx, optionValue) case ViewQueryParamLimit: uintVal, err := normalizeIntToUint(optionValue) if err != nil { - WarnfCtx(context.Background(), "ViewQueryParamLimit error: %v", err) + WarnfCtx(ctx, "ViewQueryParamLimit error: %v", err) } viewOpts.Limit = uint32(uintVal) case ViewQueryParamDescending: - if asBool(optionValue) == true { + if asBool(ctx, optionValue) == true { viewOpts.Order = gocb.ViewOrderingDescending } case ViewQueryParamSkip: uintVal, err := normalizeIntToUint(optionValue) if err != nil { - WarnfCtx(context.Background(), "ViewQueryParamSkip error: %v", err) + WarnfCtx(ctx, "ViewQueryParamSkip error: %v", err) } viewOpts.Skip = uint32(uintVal) case ViewQueryParamGroup: - viewOpts.Group = asBool(optionValue) + viewOpts.Group = asBool(ctx, optionValue) case ViewQueryParamGroupLevel: uintVal, err := normalizeIntToUint(optionValue) if err != nil { - WarnfCtx(context.Background(), "ViewQueryParamGroupLevel error: %v", err) + WarnfCtx(ctx, "ViewQueryParamGroupLevel error: %v", err) } viewOpts.GroupLevel = uint32(uintVal) case ViewQueryParamKey: @@ -343,7 +344,7 @@ func createViewOptions(params map[string]interface{}) (viewOpts *gocb.ViewOption // Default value of inclusiveEnd in Couchbase Server is true (if not specified) inclusiveEnd := true if _, ok := params[ViewQueryParamInclusiveEnd]; ok { - inclusiveEnd = asBool(params[ViewQueryParamInclusiveEnd]) + inclusiveEnd = asBool(ctx, params[ViewQueryParamInclusiveEnd]) } viewOpts.StartKey = startKey viewOpts.EndKey = endKey @@ -364,7 +365,7 @@ func createViewOptions(params map[string]interface{}) (viewOpts *gocb.ViewOption } // Used to convert the stale view parameter to a gocb ViewScanConsistency -func asViewConsistency(value interface{}) gocb.ViewScanConsistency { +func asViewConsistency(ctx context.Context, value interface{}) gocb.ViewScanConsistency { switch typeValue := value.(type) { case string: @@ -376,7 +377,7 @@ func asViewConsistency(value interface{}) gocb.ViewScanConsistency { } parsedVal, err := strconv.ParseBool(typeValue) if err != nil { - WarnfCtx(context.Background(), "asStale called with unknown value: %v. defaulting to stale=false", typeValue) + WarnfCtx(ctx, "asStale called with unknown value: %v. defaulting to stale=false", typeValue) return gocb.ViewScanConsistencyRequestPlus } if parsedVal { @@ -391,7 +392,7 @@ func asViewConsistency(value interface{}) gocb.ViewScanConsistency { return gocb.ViewScanConsistencyRequestPlus } default: - WarnfCtx(context.Background(), "asViewConsistency called with unknown type: %T. defaulting to RequestPlus", typeValue) + WarnfCtx(ctx, "asViewConsistency called with unknown type: %T. defaulting to RequestPlus", typeValue) return gocb.ViewScanConsistencyRequestPlus } diff --git a/base/collection_xattr.go b/base/collection_xattr.go index 0262fd6d9c..857f7c22b4 100644 --- a/base/collection_xattr.go +++ b/base/collection_xattr.go @@ -45,15 +45,18 @@ func (c *Collection) GetSpec() BucketSpec { // Implementation of the XattrStore interface primarily invokes common wrappers that in turn invoke SDK-specific SubdocXattrStore API func (c *Collection) WriteCasWithXattr(k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, v interface{}, xv interface{}) (casOut uint64, err error) { - return WriteCasWithXattr(c, k, xattrKey, exp, cas, opts, v, xv) + ctx := context.TODO() // fix in sg-bucket + return WriteCasWithXattr(ctx, c, k, xattrKey, exp, cas, opts, v, xv) } func (c *Collection) WriteWithXattr(k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, v []byte, xv []byte, isDelete bool, deleteBody bool) (casOut uint64, err error) { // If this is a tombstone, we want to delete the document and update the xattr - return WriteWithXattr(c, k, xattrKey, exp, cas, opts, v, xv, isDelete, deleteBody) + ctx := context.TODO() // fix in sg-bucket + return WriteWithXattr(ctx, c, k, xattrKey, exp, cas, opts, v, xv, isDelete, deleteBody) } func (c *Collection) DeleteWithXattr(k string, xattrKey string) error { - return DeleteWithXattr(c, k, xattrKey) + ctx := context.TODO() // fix in sg-bucket + return DeleteWithXattr(ctx, c, k, xattrKey) } func (c *Collection) GetXattr(k string, xattrKey string, xv interface{}) (casOut uint64, err error) { @@ -61,11 +64,13 @@ func (c *Collection) GetXattr(k string, xattrKey string, xv interface{}) (casOut } func (c *Collection) GetSubDocRaw(k string, subdocKey string) ([]byte, uint64, error) { - return c.SubdocGetRaw(k, subdocKey) + ctx := context.TODO() // fix in sg-bucket + return c.SubdocGetRaw(ctx, k, subdocKey) } func (c *Collection) WriteSubDoc(k string, subdocKey string, cas uint64, value []byte) (uint64, error) { - return c.SubdocWrite(k, subdocKey, cas, value) + ctx := context.TODO() // fix in sg-bucket + return c.SubdocWrite(ctx, k, subdocKey, cas, value) } func (c *Collection) GetWithXattr(k string, xattrKey string, userXattrKey string, rv interface{}, xv interface{}, uxv interface{}) (cas uint64, err error) { @@ -73,19 +78,23 @@ func (c *Collection) GetWithXattr(k string, xattrKey string, userXattrKey string } func (c *Collection) WriteUpdateWithXattr(k string, xattrKey string, userXattrKey string, exp uint32, opts *sgbucket.MutateInOptions, previous *sgbucket.BucketDocument, callback sgbucket.WriteUpdateWithXattrFunc) (casOut uint64, err error) { - return WriteUpdateWithXattr(c, k, xattrKey, userXattrKey, exp, opts, previous, callback) + ctx := context.TODO() // fix in sg-bucket + return WriteUpdateWithXattr(ctx, c, k, xattrKey, userXattrKey, exp, opts, previous, callback) } func (c *Collection) SetXattr(k string, xattrKey string, xv []byte) (casOut uint64, err error) { - return SetXattr(c, k, xattrKey, xv) + ctx := context.TODO() // fix in sg-bucket + return SetXattr(ctx, c, k, xattrKey, xv) } func (c *Collection) RemoveXattr(k string, xattrKey string, cas uint64) (err error) { - return RemoveXattr(c, k, xattrKey, cas) + ctx := context.TODO() // fix in sg-bucket + return RemoveXattr(ctx, c, k, xattrKey, cas) } func (c *Collection) DeleteXattrs(k string, xattrKeys ...string) (err error) { - return DeleteXattrs(c, k, xattrKeys...) + ctx := context.TODO() // fix in sg-bucket + return DeleteXattrs(ctx, c, k, xattrKeys...) } // SubdocGetXattr retrieves the named xattr @@ -96,6 +105,7 @@ func (c *Collection) SubdocGetXattr(k string, xattrKey string, xv interface{}) ( c.Bucket.waitForAvailKvOp() defer c.Bucket.releaseKvOp() + ctx := context.TODO() // fix in sg-bucket ops := []gocb.LookupInSpec{ gocb.GetSpec(xattrKey, GetSpecXattr), } @@ -104,20 +114,20 @@ func (c *Collection) SubdocGetXattr(k string, xattrKey string, xv interface{}) ( xattrContErr := res.ContentAt(0, xv) // On error here, treat as the xattr wasn't found if xattrContErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContErr) + DebugfCtx(ctx, KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContErr) return 0, ErrXattrNotFound } cas := uint64(res.Cas()) return cas, nil } else if errors.Is(lookupErr, gocbcore.ErrDocumentNotFound) { - DebugfCtx(context.TODO(), KeyCRUD, "No document found for key=%s", UD(k)) + DebugfCtx(ctx, KeyCRUD, "No document found for key=%s", UD(k)) return 0, ErrNotFound } else { return 0, lookupErr } } -func (c *Collection) SubdocGetRaw(k string, subdocKey string) ([]byte, uint64, error) { +func (c *Collection) SubdocGetRaw(ctx context.Context, k string, subdocKey string) ([]byte, uint64, error) { c.Bucket.waitForAvailKvOp() defer c.Bucket.releaseKvOp() @@ -150,7 +160,7 @@ func (c *Collection) SubdocGetRaw(k string, subdocKey string) ([]byte, uint64, e return false, nil, uint64(res.Cas()) } - err, casOut := RetryLoopCas("SubdocGetRaw", worker, c.Bucket.Spec.RetrySleeper()) + err, casOut := RetryLoopCas(ctx, "SubdocGetRaw", worker, c.Bucket.Spec.RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "SubdocGetRaw with key %s and subdocKey %s", UD(k).Redact(), UD(subdocKey).Redact()) } @@ -158,7 +168,7 @@ func (c *Collection) SubdocGetRaw(k string, subdocKey string) ([]byte, uint64, e return rawValue, casOut, err } -func (c *Collection) SubdocWrite(k string, subdocKey string, cas uint64, value []byte) (uint64, error) { +func (c *Collection) SubdocWrite(ctx context.Context, k string, subdocKey string, cas uint64, value []byte) (uint64, error) { c.Bucket.waitForAvailKvOp() defer c.Bucket.releaseKvOp() @@ -183,7 +193,7 @@ func (c *Collection) SubdocWrite(k string, subdocKey string, cas uint64, value [ return false, err, 0 } - err, casOut := RetryLoopCas("SubdocWrite", worker, c.Bucket.Spec.RetrySleeper()) + err, casOut := RetryLoopCas(ctx, "SubdocWrite", worker, c.Bucket.Spec.RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "SubdocWrite with key %s and subdocKey %s", UD(k).Redact(), UD(subdocKey).Redact()) } @@ -193,6 +203,7 @@ func (c *Collection) SubdocWrite(k string, subdocKey string, cas uint64, value [ // SubdocGetBodyAndXattr retrieves the document body and xattr in a single LookupIn subdoc operation. Does not require both to exist. func (c *Collection) SubdocGetBodyAndXattr(k string, xattrKey string, userXattrKey string, rv interface{}, xv interface{}, uxv interface{}) (cas uint64, err error) { + ctx := context.TODO() // fix in sg-bucket worker := func() (shouldRetry bool, err error, value uint64) { c.Bucket.waitForAvailKvOp() @@ -217,16 +228,16 @@ func (c *Collection) SubdocGetBodyAndXattr(k string, xattrKey string, userXattrK if isKVError(docContentErr, memd.StatusSubDocMultiPathFailureDeleted) && isKVError(xattrContentErr, memd.StatusSubDocMultiPathFailureDeleted) { // No doc, no xattr can be treated as NotFound from Sync Gateway's perspective, even if it is a server tombstone, but should return cas - DebugfCtx(context.TODO(), KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) + DebugfCtx(ctx, KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) return false, ErrNotFound, cas } if docContentErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No document body found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), docContentErr) + DebugfCtx(ctx, KeyCRUD, "No document body found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), docContentErr) } // Attempt to retrieve the xattr, if present if xattrContentErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) + DebugfCtx(ctx, KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) } case gocbcore.ErrMemdSubDocMultiPathFailureDeleted: @@ -235,7 +246,7 @@ func (c *Collection) SubdocGetBodyAndXattr(k string, xattrKey string, userXattrK cas = uint64(res.Cas()) if xattrContentErr != nil { // No doc, no xattr means the doc isn't found - DebugfCtx(context.TODO(), KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) + DebugfCtx(ctx, KeyCRUD, "No xattr content found for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), xattrContentErr) return false, ErrNotFound, cas } return false, nil, cas @@ -272,7 +283,7 @@ func (c *Collection) SubdocGetBodyAndXattr(k string, xattrKey string, userXattrK } // Kick off retry loop - err, cas = RetryLoopCas("SubdocGetBodyAndXattr", worker, c.Bucket.Spec.RetrySleeper()) + err, cas = RetryLoopCas(ctx, "SubdocGetBodyAndXattr", worker, c.Bucket.Spec.RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "SubdocGetBodyAndXattr %v", UD(k).Redact()) } diff --git a/base/collection_xattr_common.go b/base/collection_xattr_common.go index 146842c128..18c90fb92f 100644 --- a/base/collection_xattr_common.go +++ b/base/collection_xattr_common.go @@ -31,7 +31,7 @@ type KvXattrStore interface { } // CAS-safe write of a document and it's associated named xattr -func WriteCasWithXattr(store *Collection, k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, v interface{}, xv interface{}) (casOut uint64, err error) { +func WriteCasWithXattr(ctx context.Context, store *Collection, k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, v interface{}, xv interface{}) (casOut uint64, err error) { worker := func() (shouldRetry bool, err error, value uint64) { @@ -65,7 +65,7 @@ func WriteCasWithXattr(store *Collection, k string, xattrKey string, exp uint32, } // Kick off retry loop - err, cas = RetryLoopCas("WriteCasWithXattr", worker, store.GetSpec().RetrySleeper()) + err, cas = RetryLoopCas(ctx, "WriteCasWithXattr", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "WriteCasWithXattr with key %v", UD(k).Redact()) } @@ -75,17 +75,17 @@ func WriteCasWithXattr(store *Collection, k string, xattrKey string, exp uint32, // Single attempt to update a document and xattr. Setting isDelete=true and value=nil will delete the document body. Both // update types (UpdateTombstoneXattr, WriteCasWithXattr) include recoverable error retry. -func WriteWithXattr(store *Collection, k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, value []byte, xattrValue []byte, isDelete bool, deleteBody bool) (casOut uint64, err error) { // If this is a tombstone, we want to delete the document and update the xattr +func WriteWithXattr(ctx context.Context, store *Collection, k string, xattrKey string, exp uint32, cas uint64, opts *sgbucket.MutateInOptions, value []byte, xattrValue []byte, isDelete bool, deleteBody bool) (casOut uint64, err error) { // If this is a tombstone, we want to delete the document and update the xattr if isDelete { - return UpdateTombstoneXattr(store, k, xattrKey, exp, cas, xattrValue, deleteBody) + return UpdateTombstoneXattr(ctx, store, k, xattrKey, exp, cas, xattrValue, deleteBody) } else { // Not a delete - update the body and xattr - return WriteCasWithXattr(store, k, xattrKey, exp, cas, opts, value, xattrValue) + return WriteCasWithXattr(ctx, store, k, xattrKey, exp, cas, opts, value, xattrValue) } } // CAS-safe update of a document's xattr (only). Deletes the document body if deleteBody is true. -func UpdateTombstoneXattr(store *Collection, k string, xattrKey string, exp uint32, cas uint64, xv interface{}, deleteBody bool) (casOut uint64, err error) { +func UpdateTombstoneXattr(ctx context.Context, store *Collection, k string, xattrKey string, exp uint32, cas uint64, xv interface{}, deleteBody bool) (casOut uint64, err error) { // WriteCasWithXattr always stamps the xattr with the new cas using macro expansion, into a top-level property called 'cas'. // This is the only use case for macro expansion today - if more cases turn up, should change the sg-bucket API to handle this more generically. @@ -118,7 +118,7 @@ func UpdateTombstoneXattr(store *Collection, k string, xattrKey string, exp uint } // Kick off retry loop - err, cas = RetryLoopCas("UpdateTombstoneXattr", worker, store.GetSpec().RetrySleeper()) + err, cas = RetryLoopCas(ctx, "UpdateTombstoneXattr", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "Error during UpdateTombstoneXattr with key %v", UD(k).Redact()) return cas, err @@ -145,7 +145,7 @@ func UpdateTombstoneXattr(store *Collection, k string, xattrKey string, exp uint return false, nil, casOut } - err, cas = RetryLoopCas("UpdateXattrDeleteBodySecondOp", worker, store.GetSpec().RetrySleeper()) + err, cas = RetryLoopCas(ctx, "UpdateXattrDeleteBodySecondOp", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "Error during UpdateTombstoneXattr delete op with key %v", UD(k).Redact()) return cas, err @@ -164,7 +164,7 @@ func UpdateTombstoneXattr(store *Collection, k string, xattrKey string, exp uint // A zero CAS in `previous` is interpreted as no document existing; this can be used to short- // circuit the initial Get when the document is unlikely to already exist. -func WriteUpdateWithXattr(store *Collection, k string, xattrKey string, userXattrKey string, exp uint32, opts *sgbucket.MutateInOptions, previous *sgbucket.BucketDocument, callback sgbucket.WriteUpdateWithXattrFunc) (casOut uint64, err error) { +func WriteUpdateWithXattr(ctx context.Context, store *Collection, k string, xattrKey string, userXattrKey string, exp uint32, opts *sgbucket.MutateInOptions, previous *sgbucket.BucketDocument, callback sgbucket.WriteUpdateWithXattrFunc) (casOut uint64, err error) { var value []byte var xattrValue []byte @@ -192,7 +192,7 @@ func WriteUpdateWithXattr(store *Collection, k string, xattrKey string, userXatt if err != nil { if pkgerrors.Cause(err) != ErrNotFound { // Unexpected error, cancel writeupdate - DebugfCtx(context.TODO(), KeyCRUD, "Retrieval of existing doc failed during WriteUpdateWithXattr for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), err) + DebugfCtx(ctx, KeyCRUD, "Retrieval of existing doc failed during WriteUpdateWithXattr for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), err) return emptyCas, err } // Key not found - initialize values @@ -219,7 +219,7 @@ func WriteUpdateWithXattr(store *Collection, k string, xattrKey string, userXatt // Attempt to write the updated document to the bucket. Mark body for deletion if previous body was non-empty deleteBody := value != nil - casOut, writeErr := WriteWithXattr(store, k, xattrKey, exp, cas, opts, updatedValue, updatedXattrValue, isDelete, deleteBody) + casOut, writeErr := WriteWithXattr(ctx, store, k, xattrKey, exp, cas, opts, updatedValue, updatedXattrValue, isDelete, deleteBody) if writeErr == nil { return casOut, nil @@ -231,7 +231,7 @@ func WriteUpdateWithXattr(store *Collection, k string, xattrKey string, userXatt // conflict/duplicate handling on retry. } else { // WriteWithXattr already handles retry on recoverable errors, so fail on any errors other than ErrKeyExists - WarnfCtx(context.TODO(), "Failed to update doc with xattr for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), writeErr) + WarnfCtx(ctx, "Failed to update doc with xattr for key=%s, xattrKey=%s: %v", UD(k), UD(xattrKey), writeErr) return emptyCas, writeErr } @@ -243,7 +243,7 @@ func WriteUpdateWithXattr(store *Collection, k string, xattrKey string, userXatt } // SetXattr performs a subdoc set on the supplied xattrKey. Implements a retry for recoverable failures. -func SetXattr(store *Collection, k string, xattrKey string, xv []byte) (casOut uint64, err error) { +func SetXattr(ctx context.Context, store *Collection, k string, xattrKey string, xv []byte) (casOut uint64, err error) { worker := func() (shouldRetry bool, err error, value uint64) { casOut, writeErr := store.SubdocSetXattr(k, xattrKey, xv) @@ -259,7 +259,7 @@ func SetXattr(store *Collection, k string, xattrKey string, xv []byte) (casOut u return false, writeErr, 0 } - err, casOut = RetryLoopCas("SetXattr", worker, store.GetSpec().RetrySleeper()) + err, casOut = RetryLoopCas(ctx, "SetXattr", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "SetXattr with key %v", UD(k).Redact()) } @@ -269,7 +269,7 @@ func SetXattr(store *Collection, k string, xattrKey string, xv []byte) (casOut u } // RemoveXattr performs a cas safe subdoc delete of the provided key. Will retry if a recoverable failure occurs. -func RemoveXattr(store *Collection, k string, xattrKey string, cas uint64) error { +func RemoveXattr(ctx context.Context, store *Collection, k string, xattrKey string, cas uint64) error { worker := func() (shouldRetry bool, err error, value interface{}) { writeErr := store.SubdocDeleteXattr(k, xattrKey, cas) if writeErr == nil { @@ -284,7 +284,7 @@ func RemoveXattr(store *Collection, k string, xattrKey string, cas uint64) error return false, err, nil } - err, _ := RetryLoop("RemoveXattr", worker, store.GetSpec().RetrySleeper()) + err, _ := RetryLoop(ctx, "RemoveXattr", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "RemoveXattr with key %v xattr %v", UD(k).Redact(), UD(xattrKey).Redact()) } @@ -294,7 +294,7 @@ func RemoveXattr(store *Collection, k string, xattrKey string, cas uint64) error // DeleteXattrs performs a subdoc delete of the provided keys. Retries any recoverable failures. Not cas safe does a // straight delete. -func DeleteXattrs(store *Collection, k string, xattrKeys ...string) error { +func DeleteXattrs(ctx context.Context, store *Collection, k string, xattrKeys ...string) error { worker := func() (shouldRetry bool, err error, value interface{}) { writeErr := store.SubdocDeleteXattrs(k, xattrKeys...) if writeErr == nil { @@ -309,7 +309,7 @@ func DeleteXattrs(store *Collection, k string, xattrKeys ...string) error { return false, err, nil } - err, _ := RetryLoop("DeleteXattrs", worker, store.GetSpec().RetrySleeper()) + err, _ := RetryLoop(ctx, "DeleteXattrs", worker, store.GetSpec().RetrySleeper()) if err != nil { err = pkgerrors.Wrapf(err, "DeleteXattrs with keys %q xattr %v", UD(k).Redact(), UD(strings.Join(xattrKeys, ",")).Redact()) } @@ -331,18 +331,18 @@ func DeleteXattrs(store *Collection, k string, xattrKeys ...string) error { // Expected errors: // - Temporary server overloaded errors, in which case the caller should retry // - If the doc is in the the NoDoc and NoXattr state, it will return a KeyNotFound error -func DeleteWithXattr(store *Collection, k string, xattrKey string) error { +func DeleteWithXattr(ctx context.Context, store *Collection, k string, xattrKey string) error { // Delegate to internal method that can take a testing-related callback - return deleteWithXattrInternal(store, k, xattrKey, nil) + return deleteWithXattrInternal(ctx, store, k, xattrKey, nil) } // A function that will be called back after the first delete attempt but before second delete attempt // to simulate the doc having changed state (artifiically injected race condition) type deleteWithXattrRaceInjection func(k string, xattrKey string) -func deleteWithXattrInternal(store *Collection, k string, xattrKey string, callback deleteWithXattrRaceInjection) error { +func deleteWithXattrInternal(ctx context.Context, store *Collection, k string, xattrKey string, callback deleteWithXattrRaceInjection) error { - DebugfCtx(context.TODO(), KeyCRUD, "DeleteWithXattr called with key: %v xattrKey: %v", UD(k), UD(xattrKey)) + DebugfCtx(ctx, KeyCRUD, "DeleteWithXattr called with key: %v xattrKey: %v", UD(k), UD(xattrKey)) // Try to delete body and xattrs in single op // NOTE: ongoing discussion w/ KV Engine team on whether this should handle cases where the body diff --git a/base/config_persistence.go b/base/config_persistence.go index 4c2a28b9da..d262f84ccc 100644 --- a/base/config_persistence.go +++ b/base/config_persistence.go @@ -24,13 +24,13 @@ import ( type ConfigPersistence interface { // Operations for interacting with raw config ([]byte). gocb.Cas values represent document cas, // cfgCas represent the cas value associated with the last mutation, and may not match document CAS - loadRawConfig(c *gocb.Collection, key string) ([]byte, gocb.Cas, error) + loadRawConfig(ctx context.Context, c *gocb.Collection, key string) ([]byte, gocb.Cas, error) removeRawConfig(c *gocb.Collection, key string, cas gocb.Cas) (gocb.Cas, error) replaceRawConfig(c *gocb.Collection, key string, value []byte, cas gocb.Cas) (casOut gocb.Cas, err error) // Operations for interacting with marshalled config. cfgCas represents the cas value // associated with the last config mutation, and may not match document CAS - loadConfig(c *gocb.Collection, key string, valuePtr interface{}) (cfgCas uint64, err error) + loadConfig(ctx context.Context, c *gocb.Collection, key string, valuePtr interface{}) (cfgCas uint64, err error) insertConfig(c *gocb.Collection, key string, value interface{}) (cfgCas uint64, err error) // touchConfigRollback sets the specific property to the specified string value via a subdoc operation. @@ -93,7 +93,7 @@ func (xbp *XattrBootstrapPersistence) touchConfigRollback(c *gocb.Collection, ke // loadRawConfig returns the config and document cas (not cfgCas). Does not restore deleted documents, // to avoid cas collisions with concurrent updates -func (xbp *XattrBootstrapPersistence) loadRawConfig(c *gocb.Collection, key string) ([]byte, gocb.Cas, error) { +func (xbp *XattrBootstrapPersistence) loadRawConfig(ctx context.Context, c *gocb.Collection, key string) ([]byte, gocb.Cas, error) { var rawValue []byte ops := []gocb.LookupInSpec{ @@ -109,12 +109,12 @@ func (xbp *XattrBootstrapPersistence) loadRawConfig(c *gocb.Collection, key stri // config xattrContErr := res.ContentAt(0, &rawValue) if xattrContErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No xattr config found for key=%s, path=%s: %v", key, cfgXattrConfigPath, xattrContErr) + DebugfCtx(ctx, KeyCRUD, "No xattr config found for key=%s, path=%s: %v", key, cfgXattrConfigPath, xattrContErr) return rawValue, 0, ErrNotFound } return rawValue, res.Cas(), nil } else if errors.Is(lookupErr, gocbcore.ErrDocumentNotFound) { - DebugfCtx(context.TODO(), KeyCRUD, "No config document found for key=%s", key) + DebugfCtx(ctx, KeyCRUD, "No config document found for key=%s", key) return rawValue, 0, ErrNotFound } else { return rawValue, 0, lookupErr @@ -167,7 +167,7 @@ func (xbp *XattrBootstrapPersistence) replaceRawConfig(c *gocb.Collection, key s // loadConfig returns the cas associated with the last cfg change (xattr._sync.cas). If a deleted document body is // detected, recreates the document to avoid metadata purge -func (xbp *XattrBootstrapPersistence) loadConfig(c *gocb.Collection, key string, valuePtr interface{}) (cas uint64, err error) { +func (xbp *XattrBootstrapPersistence) loadConfig(ctx context.Context, c *gocb.Collection, key string, valuePtr interface{}) (cas uint64, err error) { ops := []gocb.LookupInSpec{ gocb.GetSpec(cfgXattrConfigPath, GetSpecXattr), @@ -184,7 +184,7 @@ func (xbp *XattrBootstrapPersistence) loadConfig(c *gocb.Collection, key string, // config xattrContErr := res.ContentAt(0, valuePtr) if xattrContErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No xattr config found for key=%s, path=%s: %v", key, cfgXattrConfigPath, xattrContErr) + DebugfCtx(ctx, KeyCRUD, "No xattr config found for key=%s, path=%s: %v", key, cfgXattrConfigPath, xattrContErr) return 0, ErrNotFound } @@ -192,7 +192,7 @@ func (xbp *XattrBootstrapPersistence) loadConfig(c *gocb.Collection, key string, var strCas string xattrCasErr := res.ContentAt(1, &strCas) if xattrCasErr != nil { - DebugfCtx(context.TODO(), KeyCRUD, "No xattr cas found for key=%s, path=%s: %v", key, cfgXattrCasPath, xattrContErr) + DebugfCtx(ctx, KeyCRUD, "No xattr cas found for key=%s, path=%s: %v", key, cfgXattrCasPath, xattrContErr) return 0, ErrNotFound } cfgCas := HexCasToUint64(strCas) @@ -203,12 +203,12 @@ func (xbp *XattrBootstrapPersistence) loadConfig(c *gocb.Collection, key string, if bodyErr != nil { restoreErr := xbp.restoreDocumentBody(c, key, valuePtr, strCas) if restoreErr != nil { - WarnfCtx(context.TODO(), "Error attempting to restore unexpected deletion of config: %v", restoreErr) + WarnfCtx(ctx, "Error attempting to restore unexpected deletion of config: %v", restoreErr) } } return cfgCas, nil } else if errors.Is(lookupErr, gocbcore.ErrDocumentNotFound) { - DebugfCtx(context.TODO(), KeyCRUD, "No config document found for key=%s", key) + DebugfCtx(ctx, KeyCRUD, "No config document found for key=%s", key) return 0, ErrNotFound } else { return 0, lookupErr @@ -241,7 +241,7 @@ type DocumentBootstrapPersistence struct { CommonBootstrapPersistence } -func (dbp *DocumentBootstrapPersistence) loadRawConfig(c *gocb.Collection, key string) ([]byte, gocb.Cas, error) { +func (dbp *DocumentBootstrapPersistence) loadRawConfig(_ context.Context, c *gocb.Collection, key string) ([]byte, gocb.Cas, error) { res, err := c.Get(key, &gocb.GetOptions{ Transcoder: gocb.NewRawJSONTranscoder(), }) @@ -276,7 +276,7 @@ func (dbp *DocumentBootstrapPersistence) replaceRawConfig(c *gocb.Collection, ke return replaceRes.Cas(), nil } -func (dbp *DocumentBootstrapPersistence) loadConfig(c *gocb.Collection, key string, valuePtr interface{}) (cas uint64, err error) { +func (dbp *DocumentBootstrapPersistence) loadConfig(_ context.Context, c *gocb.Collection, key string, valuePtr interface{}) (cas uint64, err error) { res, err := c.Get(key, &gocb.GetOptions{ Timeout: time.Second * 10, diff --git a/base/config_persistence_test.go b/base/config_persistence_test.go index 5fc7ccf83c..a1c3d7dfc4 100644 --- a/base/config_persistence_test.go +++ b/base/config_persistence_test.go @@ -62,13 +62,13 @@ func TestConfigPersistence(t *testing.T) { _, reinsertErr := cp.insertConfig(c, configKey, configBody) require.Equal(t, ErrAlreadyExists, reinsertErr) + ctx := TestCtx(t) var loadedConfig map[string]interface{} - loadCas, loadErr := cp.loadConfig(c, configKey, &loadedConfig) + loadCas, loadErr := cp.loadConfig(ctx, c, configKey, &loadedConfig) require.NoError(t, loadErr) assert.Equal(t, insertCas, loadCas) assert.Equal(t, configBody["sampleConfig"], loadedConfig["sampleConfig"]) - - rawConfig, rawCas, rawErr := cp.loadRawConfig(c, configKey) + rawConfig, rawCas, rawErr := cp.loadRawConfig(ctx, c, configKey) require.NoError(t, rawErr) assert.Equal(t, insertCas, uint64(rawCas)) assert.Equal(t, rawConfigBody, rawConfig) @@ -87,13 +87,13 @@ func TestConfigPersistence(t *testing.T) { // retrieve config, validate updated value var updatedConfig map[string]interface{} - loadCas, loadErr = cp.loadConfig(c, configKey, &updatedConfig) + loadCas, loadErr = cp.loadConfig(ctx, c, configKey, &updatedConfig) require.NoError(t, loadErr) assert.Equal(t, updateCas, gocb.Cas(loadCas)) assert.Equal(t, configBody["updated"], updatedConfig["updated"]) // retrieve raw config, validate updated value - rawConfig, rawCas, rawErr = cp.loadRawConfig(c, configKey) + rawConfig, rawCas, rawErr = cp.loadRawConfig(ctx, c, configKey) require.NoError(t, rawErr) assert.Equal(t, updateCas, rawCas) assert.Equal(t, updatedRawBody, rawConfig) @@ -108,11 +108,11 @@ func TestConfigPersistence(t *testing.T) { // attempt to retrieve config, validate not found var deletedConfig map[string]interface{} - _, loadErr = cp.loadConfig(c, configKey, &deletedConfig) + _, loadErr = cp.loadConfig(ctx, c, configKey, &deletedConfig) assert.Equal(t, ErrNotFound, loadErr) // attempt to retrieve raw config, validate updated value - _, _, rawErr = cp.loadRawConfig(c, configKey) + _, _, rawErr = cp.loadRawConfig(ctx, c, configKey) require.Error(t, rawErr) }) } @@ -155,9 +155,10 @@ func TestXattrConfigPersistence(t *testing.T) { _, reinsertErr := cp.insertConfig(c, configKey, configBody) require.Equal(t, ErrAlreadyExists, reinsertErr) + ctx := TestCtx(t) // Retrieve the config, cas should still match insertCas var loadedConfig map[string]interface{} - loadCas, loadErr := cp.loadConfig(c, configKey, &loadedConfig) + loadCas, loadErr := cp.loadConfig(ctx, c, configKey, &loadedConfig) require.NoError(t, loadErr) assert.Equal(t, insertCas, loadCas) assert.Equal(t, configBody["sampleConfig"], loadedConfig["sampleConfig"]) @@ -167,7 +168,7 @@ func TestXattrConfigPersistence(t *testing.T) { require.NoError(t, err) // Retrieve the config, cas should still match insertCas - loadCas, loadErr = cp.loadConfig(c, configKey, &loadedConfig) + loadCas, loadErr = cp.loadConfig(ctx, c, configKey, &loadedConfig) require.NoError(t, loadErr) assert.Equal(t, insertCas, loadCas) assert.Equal(t, configBody["sampleConfig"], loadedConfig["sampleConfig"]) @@ -183,7 +184,7 @@ func TestXattrConfigPersistence(t *testing.T) { assert.NoError(t, deleteErr) // Retrieve the config, cas should still match insertCas - loadCas, loadErr = cp.loadConfig(c, configKey, &loadedConfig) + loadCas, loadErr = cp.loadConfig(ctx, c, configKey, &loadedConfig) require.NoError(t, loadErr) assert.Equal(t, insertCas, loadCas) assert.Equal(t, configBody["sampleConfig"], loadedConfig["sampleConfig"]) @@ -194,7 +195,7 @@ func TestXattrConfigPersistence(t *testing.T) { assert.True(t, docBody != nil) // Retrieve the config, cas should still match insertCas - loadCas, loadErr = cp.loadConfig(c, configKey, &loadedConfig) + loadCas, loadErr = cp.loadConfig(ctx, c, configKey, &loadedConfig) require.NoError(t, loadErr) assert.Equal(t, insertCas, loadCas) assert.Equal(t, configBody["sampleConfig"], loadedConfig["sampleConfig"]) diff --git a/base/dcp_client.go b/base/dcp_client.go index c49fec4451..8b7d35998d 100644 --- a/base/dcp_client.go +++ b/base/dcp_client.go @@ -40,6 +40,7 @@ type endStreamCallbackFunc func(e endStreamEvent) var ErrVbUUIDMismatch = errors.New("VbUUID mismatch when failOnRollback set") type DCPClient struct { + ctx context.Context ID string // unique ID for DCPClient - used for DCP stream name, must be unique agent *gocbcore.DCPAgent // SDK DCP agent, manages connections and calls back to DCPClient stream observer implementation callback sgbucket.FeedEventCallbackFunc // Callback invoked on DCP mutations/deletions @@ -79,7 +80,7 @@ type DCPClientOptions struct { CheckpointPrefix string } -func NewDCPClient(ID string, callback sgbucket.FeedEventCallbackFunc, options DCPClientOptions, bucket *GocbV2Bucket) (*DCPClient, error) { +func NewDCPClient(ctx context.Context, ID string, callback sgbucket.FeedEventCallbackFunc, options DCPClientOptions, bucket *GocbV2Bucket) (*DCPClient, error) { numWorkers := DefaultNumWorkers if options.NumWorkers > 0 { @@ -101,6 +102,7 @@ func NewDCPClient(ID string, callback sgbucket.FeedEventCallbackFunc, options DC } } client := &DCPClient{ + ctx: ctx, workers: make([]*DCPWorker, numWorkers), numVbuckets: numVbuckets, callback: callback, @@ -127,7 +129,7 @@ func NewDCPClient(ID string, callback sgbucket.FeedEventCallbackFunc, options DC case DCPMetadataStoreCS: // TODO: Change GetSingleDataStore to a metadata Store? metadataStore := bucket.DefaultDataStore() - client.metadata = NewDCPMetadataCS(metadataStore, numVbuckets, numWorkers, checkpointPrefix) + client.metadata = NewDCPMetadataCS(ctx, metadataStore, numVbuckets, numWorkers, checkpointPrefix) case DCPMetadataStoreInMemory: client.metadata = NewDCPMetadataMem(numVbuckets) default: @@ -255,7 +257,7 @@ func (dc *DCPClient) Start() (doneChan chan error, err error) { return dc.doneChannel, err } } - dc.startWorkers() + dc.startWorkers(dc.ctx) for i := uint16(0); i < dc.numVbuckets; i++ { openErr := dc.openStream(i, openRetryCount) @@ -287,7 +289,7 @@ func (dc *DCPClient) close() { // set dc.closing to true, avoid re-triggering close if it's already in progress if !dc.closing.CompareAndSwap(false, true) { - InfofCtx(context.TODO(), KeyDCP, "DCP Client close called - client is already closing") + InfofCtx(dc.ctx, KeyDCP, "DCP Client close called - client is already closing") return } @@ -296,7 +298,7 @@ func (dc *DCPClient) close() { if dc.agent != nil { agentErr := dc.agent.Close() if agentErr != nil { - WarnfCtx(context.TODO(), "Error closing DCP agent in client close: %v", agentErr) + WarnfCtx(dc.ctx, "Error closing DCP agent in client close: %v", agentErr) } } @@ -320,16 +322,16 @@ func (dc *DCPClient) initAgent(spec BucketSpec) error { } agentConfig := gocbcore.DCPAgentConfig{} - DebugfCtx(context.TODO(), KeyAll, "Parsing cluster connection string %q", UD(connStr)) + DebugfCtx(dc.ctx, KeyAll, "Parsing cluster connection string %q", UD(connStr)) beforeFromConnStr := time.Now() connStrError := agentConfig.FromConnStr(connStr) if connStrError != nil { return fmt.Errorf("Unable to start DCP Client - error building conn str: %v", connStrError) } if d := time.Since(beforeFromConnStr); d > FromConnStrWarningThreshold { - WarnfCtx(context.TODO(), "Parsed cluster connection string %q in: %v", UD(connStr), d) + WarnfCtx(dc.ctx, "Parsed cluster connection string %q in: %v", UD(connStr), d) } else { - DebugfCtx(context.TODO(), KeyAll, "Parsed cluster connection string %q in: %v", UD(connStr), d) + DebugfCtx(dc.ctx, KeyAll, "Parsed cluster connection string %q in: %v", UD(connStr), d) } auth, authErr := spec.GocbcoreAuthProvider() @@ -337,7 +339,7 @@ func (dc *DCPClient) initAgent(spec BucketSpec) error { return fmt.Errorf("Unable to start DCP Client - error creating authenticator: %w", authErr) } - tlsRootCAProvider, err := GoCBCoreTLSRootCAProvider(&spec.TLSSkipVerify, spec.CACertPath) + tlsRootCAProvider, err := GoCBCoreTLSRootCAProvider(dc.ctx, &spec.TLSSkipVerify, spec.CACertPath) if err != nil { return err } @@ -400,7 +402,7 @@ func (dc *DCPClient) workerForVbno(vbNo uint16) *DCPWorker { } // startWorkers initializes the DCP workers to receive stream events from eventFeed -func (dc *DCPClient) startWorkers() { +func (dc *DCPClient) startWorkers(ctx context.Context) { // vbuckets are assigned to workers as vbNo % NumWorkers. Create set of assigned vbuckets assignedVbs := make(map[int][]uint16) @@ -419,13 +421,12 @@ func (dc *DCPClient) startWorkers() { metaPersistFrequency: dc.checkpointPersistFrequency, } dc.workers[index] = NewDCPWorker(index, dc.metadata, dc.callback, dc.onStreamEnd, dc.terminator, nil, dc.checkpointPrefix, assignedVbs[index], options) - dc.workers[index].Start(&dc.workersWg) + dc.workers[index].Start(ctx, &dc.workersWg) } } func (dc *DCPClient) openStream(vbID uint16, maxRetries uint32) error { - logCtx := context.TODO() var openStreamErr error var attempts uint32 for { @@ -446,26 +447,26 @@ func (dc *DCPClient) openStream(vbID uint16, maxRetries uint32) error { switch { case errors.As(openStreamErr, &rollbackErr): if dc.failOnRollback { - InfofCtx(logCtx, KeyDCP, "Open stream for vbID %d failed due to rollback or range error, closing client based on failOnRollback=true", vbID) + InfofCtx(dc.ctx, KeyDCP, "Open stream for vbID %d failed due to rollback or range error, closing client based on failOnRollback=true", vbID) return fmt.Errorf("%w, failOnRollback requested", openStreamErr) } - InfofCtx(logCtx, KeyDCP, "Open stream for vbID %d failed due to rollback or range error, will roll back metadata and retry: %v", vbID, openStreamErr) + InfofCtx(dc.ctx, KeyDCP, "Open stream for vbID %d failed due to rollback or range error, will roll back metadata and retry: %v", vbID, openStreamErr) - dc.rollback(logCtx, vbID, rollbackErr.SeqNo) + dc.rollback(dc.ctx, vbID, rollbackErr.SeqNo) case errors.Is(openStreamErr, gocbcore.ErrMemdRangeError): err := fmt.Errorf("Invalid metadata out of range for vbID %d, err: %v metadata %+v, shutting down agent", vbID, openStreamErr, dc.metadata.GetMeta(vbID)) - WarnfCtx(logCtx, "%s", err) + WarnfCtx(dc.ctx, "%s", err) return err case errors.Is(openStreamErr, ErrVbUUIDMismatch): - WarnfCtx(logCtx, "Closing Stream for vbID: %d, %s", vbID, openStreamErr) + WarnfCtx(dc.ctx, "Closing Stream for vbID: %d, %s", vbID, openStreamErr) return openStreamErr case errors.Is(openStreamErr, gocbcore.ErrShutdown): - WarnfCtx(logCtx, "Closing stream for vbID %d, agent has been shut down", vbID) + WarnfCtx(dc.ctx, "Closing stream for vbID %d, agent has been shut down", vbID) return openStreamErr case errors.Is(openStreamErr, ErrTimeout): - DebugfCtx(logCtx, KeyDCP, "Timeout attempting to open stream for vb %d, will retry", vbID) + DebugfCtx(dc.ctx, KeyDCP, "Timeout attempting to open stream for vb %d, will retry", vbID) default: - WarnfCtx(logCtx, "Unknown error opening stream for vbID %d: %v", vbID, openStreamErr) + WarnfCtx(dc.ctx, "Unknown error opening stream for vbID %d: %v", vbID, openStreamErr) } if maxRetries == infiniteOpenStreamRetries { continue @@ -565,29 +566,28 @@ func (dc *DCPClient) deactivateVbucket(vbID uint16) { dc.close() // On successful one-shot feed completion, purge persisted checkpoints if dc.oneShot { - dc.metadata.Purge(len(dc.workers)) + dc.metadata.Purge(dc.ctx, len(dc.workers)) } } } func (dc *DCPClient) onStreamEnd(e endStreamEvent) { - logCtx := context.TODO() if e.err == nil { - DebugfCtx(logCtx, KeyDCP, "Stream (vb:%d) closed, all items streamed", e.vbID) + DebugfCtx(dc.ctx, KeyDCP, "Stream (vb:%d) closed, all items streamed", e.vbID) dc.deactivateVbucket(e.vbID) return } if errors.Is(e.err, gocbcore.ErrDCPStreamClosed) { - DebugfCtx(logCtx, KeyDCP, "Stream (vb:%d) closed by DCPClient", e.vbID) + DebugfCtx(dc.ctx, KeyDCP, "Stream (vb:%d) closed by DCPClient", e.vbID) dc.fatalError(fmt.Errorf("Stream (vb:%d) closed by DCPClient", e.vbID)) return } if errors.Is(e.err, gocbcore.ErrDCPStreamStateChanged) || errors.Is(e.err, gocbcore.ErrDCPStreamTooSlow) || errors.Is(e.err, gocbcore.ErrDCPStreamDisconnected) { - DebugfCtx(logCtx, KeyDCP, "Stream (vb:%d) ended with a known error, will reconnect. Reason: %s", e.vbID, e.err) + DebugfCtx(dc.ctx, KeyDCP, "Stream (vb:%d) ended with a known error, will reconnect. Reason: %s", e.vbID, e.err) } else { - InfofCtx(logCtx, KeyDCP, "Stream (vb:%d) ended with an unknown error, will reconnect. Reason: %s", e.vbID, e.err) + InfofCtx(dc.ctx, KeyDCP, "Stream (vb:%d) ended with an unknown error, will reconnect. Reason: %s", e.vbID, e.err) } retries := infiniteOpenStreamRetries if dc.oneShot { diff --git a/base/dcp_client_metadata.go b/base/dcp_client_metadata.go index 8a5542bac6..f78099aab3 100644 --- a/base/dcp_client_metadata.go +++ b/base/dcp_client_metadata.go @@ -57,11 +57,11 @@ type DCPMetadataStore interface { SetFailoverEntries(vbID uint16, entries []gocbcore.FailoverEntry) // Persist writes the metadata for the specified workerID and vbucket IDs to the backing store - Persist(workerID int, vbIDs []uint16) + Persist(ctx context.Context, workerID int, vbIDs []uint16) // Purge removes all metadata associated with the metadata store from the bucket. It does not remove the // in-memory metadata. - Purge(numWorkers int) + Purge(ctx context.Context, numWorkers int) // GetKeyPrefix will retrieve the key prefix used for metadata persistence GetKeyPrefix() string @@ -149,12 +149,12 @@ func (m *dcpMetadataBase) SetEndSeqNos(endSeqNos map[uint16]uint64) { } // Persist is no-op for in-memory metadata store -func (md *DCPMetadataMem) Persist(workerID int, vbIDs []uint16) { +func (md *DCPMetadataMem) Persist(_ context.Context, workerID int, vbIDs []uint16) { return } // Purge is no-op for in-memory metadata store -func (md *DCPMetadataMem) Purge(numWorkers int) { +func (md *DCPMetadataMem) Purge(_ context.Context, numWorkers int) { return } @@ -197,7 +197,7 @@ type DCPMetadataCS struct { dcpMetadataBase } -func NewDCPMetadataCS(store DataStore, numVbuckets uint16, numWorkers int, keyPrefix string) *DCPMetadataCS { +func NewDCPMetadataCS(ctx context.Context, store DataStore, numVbuckets uint16, numWorkers int, keyPrefix string) *DCPMetadataCS { m := &DCPMetadataCS{ dataStore: store, @@ -215,7 +215,7 @@ func NewDCPMetadataCS(store DataStore, numVbuckets uint16, numWorkers int, keyPr // Initialize any persisted metadata for i := 0; i < numWorkers; i++ { - m.load(i) + m.load(ctx, i) } return m @@ -225,7 +225,7 @@ func NewDCPMetadataCS(store DataStore, numVbuckets uint16, numWorkers int, keyPr // set that has been assigned to the worker. There's no synchronization on m.metadata - relies on DCP worker to // avoid read/write races on vbucket data. Calls to persist must be blocking on the worker goroutine, and vbuckets are // only assigned to a single worker -func (m *DCPMetadataCS) Persist(workerID int, vbIDs []uint16) { +func (m *DCPMetadataCS) Persist(ctx context.Context, workerID int, vbIDs []uint16) { meta := WorkerMetadata{} meta.DCPMeta = make(map[uint16]DCPMetadata) @@ -234,36 +234,34 @@ func (m *DCPMetadataCS) Persist(workerID int, vbIDs []uint16) { } err := m.dataStore.Set(m.getMetadataKey(workerID), 0, nil, meta) if err != nil { - InfofCtx(context.TODO(), KeyDCP, "Unable to persist DCP metadata: %v", err) + InfofCtx(ctx, KeyDCP, "Unable to persist DCP metadata: %v", err) } else { - TracefCtx(context.TODO(), KeyDCP, "Persisted metadata for worker %d: %v", workerID, meta) - // log.Printf("Persisted metadata for worker %d (%s): %v", workerID, m.getMetadataKey(workerID), meta) + TracefCtx(ctx, KeyDCP, "Persisted metadata for worker %d: %v", workerID, meta) } return } -func (m *DCPMetadataCS) load(workerID int) { +func (m *DCPMetadataCS) load(ctx context.Context, workerID int) { var meta WorkerMetadata _, err := m.dataStore.Get(m.getMetadataKey(workerID), &meta) if err != nil { if IsKeyNotFoundError(m.dataStore, err) { return } - InfofCtx(context.TODO(), KeyDCP, "Error loading persisted metadata - metadata will be reset for worker %d: %s", workerID, err) + InfofCtx(ctx, KeyDCP, "Error loading persisted metadata - metadata will be reset for worker %d: %s", workerID, err) } - TracefCtx(context.TODO(), KeyDCP, "Loaded metadata for worker %d: %v", workerID, meta) - // log.Printf("Loaded metadata for worker %d (%s): %v", workerID, m.getMetadataKey(workerID), meta) + TracefCtx(ctx, KeyDCP, "Loaded metadata for worker %d: %v", workerID, meta) for vbID, metadata := range meta.DCPMeta { m.metadata[vbID] = metadata } } -func (m *DCPMetadataCS) Purge(numWorkers int) { +func (m *DCPMetadataCS) Purge(ctx context.Context, numWorkers int) { for i := 0; i < numWorkers; i++ { err := m.dataStore.Delete(m.getMetadataKey(i)) if err != nil && !IsKeyNotFoundError(m.dataStore, err) { - InfofCtx(context.TODO(), KeyDCP, "Unable to remove DCP checkpoint for key %s: %v", m.getMetadataKey(i), err) + InfofCtx(ctx, KeyDCP, "Unable to remove DCP checkpoint for key %s: %v", m.getMetadataKey(i), err) } } } diff --git a/base/dcp_client_stream_observer.go b/base/dcp_client_stream_observer.go index bab699ee72..abfc4d3a72 100644 --- a/base/dcp_client_stream_observer.go +++ b/base/dcp_client_stream_observer.go @@ -9,8 +9,6 @@ package base import ( - "context" - "github.com/couchbase/gocbcore/v10" ) @@ -29,7 +27,7 @@ func (dc *DCPClient) SnapshotMarker(snapshotMarker gocbcore.DcpSnapshotMarker) { endSeq: snapshotMarker.EndSeqNo, snapshotType: snapshotMarker.SnapshotType, } - dc.workerForVbno(snapshotMarker.VbID).Send(e) + dc.workerForVbno(snapshotMarker.VbID).Send(dc.ctx, e) } func (dc *DCPClient) Mutation(mutation gocbcore.DcpMutation) { @@ -52,7 +50,7 @@ func (dc *DCPClient) Mutation(mutation gocbcore.DcpMutation) { key: mutation.Key, value: mutation.Value, } - dc.workerForVbno(mutation.VbID).Send(e) + dc.workerForVbno(mutation.VbID).Send(dc.ctx, e) } func (dc *DCPClient) Deletion(deletion gocbcore.DcpDeletion) { @@ -73,7 +71,7 @@ func (dc *DCPClient) Deletion(deletion gocbcore.DcpDeletion) { key: deletion.Key, value: deletion.Value, } - dc.workerForVbno(deletion.VbID).Send(e) + dc.workerForVbno(deletion.VbID).Send(dc.ctx, e) } @@ -85,14 +83,14 @@ func (dc *DCPClient) End(end gocbcore.DcpStreamEnd, err error) { streamID: end.StreamID, }, err: err} - dc.workerForVbno(end.VbID).Send(e) + dc.workerForVbno(end.VbID).Send(dc.ctx, e) } func (dc *DCPClient) Expiration(expiration gocbcore.DcpExpiration) { // SG doesn't opt in to expirations, so they'll come through as deletion events // (cf.https://github.com/couchbase/kv_engine/blob/master/docs/dcp/documentation/expiry-opcode-output.md) - WarnfCtx(context.TODO(), "Unexpected DCP expiration event (vb:%d) for key %v", expiration.VbID, UD(string(expiration.Key))) + WarnfCtx(dc.ctx, "Unexpected DCP expiration event (vb:%d) for key %v", expiration.VbID, UD(string(expiration.Key))) } func (dc *DCPClient) CreateCollection(creation gocbcore.DcpCollectionCreation) { @@ -124,7 +122,7 @@ func (dc *DCPClient) OSOSnapshot(snapshot gocbcore.DcpOSOSnapshot) { } func (dc *DCPClient) SeqNoAdvanced(seqNoAdvanced gocbcore.DcpSeqNoAdvanced) { - dc.workerForVbno(seqNoAdvanced.VbID).Send(seqnoAdvancedEvent{ + dc.workerForVbno(seqNoAdvanced.VbID).Send(dc.ctx, seqnoAdvancedEvent{ streamEventCommon: streamEventCommon{ vbID: seqNoAdvanced.VbID, streamID: seqNoAdvanced.StreamID, diff --git a/base/dcp_client_test.go b/base/dcp_client_test.go index 15647522e8..431f5b55ba 100644 --- a/base/dcp_client_test.go +++ b/base/dcp_client_test.go @@ -71,7 +71,7 @@ func TestOneShotDCP(t *testing.T) { gocbv2Bucket, err := AsGocbV2Bucket(bucket.Bucket) require.NoError(t, err) - dcpClient, err := NewDCPClient(feedID, counterCallback, clientOptions, gocbv2Bucket) + dcpClient, err := NewDCPClient(TestCtx(t), feedID, counterCallback, clientOptions, gocbv2Bucket) require.NoError(t, err) doneChan, startErr := dcpClient.Start() @@ -135,7 +135,7 @@ func TestTerminateDCPFeed(t *testing.T) { options := DCPClientOptions{ CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } - dcpClient, err := NewDCPClient(feedID, counterCallback, options, gocbv2Bucket) + dcpClient, err := NewDCPClient(TestCtx(t), feedID, counterCallback, options, gocbv2Bucket) require.NoError(t, err) // Add documents in a separate goroutine @@ -245,9 +245,10 @@ func TestDCPClientMultiFeedConsistency(t *testing.T) { CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } + ctx := TestCtx(t) gocbv2Bucket, err := AsGocbV2Bucket(bucket.Bucket) require.NoError(t, err) - dcpClient, err := NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + dcpClient, err := NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan, startErr := dcpClient.Start() @@ -281,7 +282,7 @@ func TestDCPClientMultiFeedConsistency(t *testing.T) { CollectionIDs: collectionIDs, CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } - dcpClient2, err := NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + dcpClient2, err := NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan2, startErr2 := dcpClient2.Start() @@ -300,7 +301,7 @@ func TestDCPClientMultiFeedConsistency(t *testing.T) { CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } - dcpClient3, err := NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + dcpClient3, err := NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan3, startErr3 := dcpClient3.Start() @@ -371,7 +372,8 @@ func TestContinuousDCPRollback(t *testing.T) { // timeout for feed to complete timeout := time.After(20 * time.Second) - dcpClient, err := NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + ctx := TestCtx(t) + dcpClient, err := NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) _, startErr := dcpClient.Start() @@ -406,7 +408,7 @@ func TestContinuousDCPRollback(t *testing.T) { } require.NoError(t, dcpClient.Close()) - dcpClient1, err := NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + dcpClient1, err := NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) // function to force the rollback of some vBuckets dcpClient1.forceRollbackvBucket(vbUUID) @@ -497,10 +499,11 @@ func TestResumeStoppedFeed(t *testing.T) { CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } + ctx := TestCtx(t) gocbv2Bucket, err := AsGocbV2Bucket(bucket.Bucket) require.NoError(t, err) - dcpClient, err = NewDCPClient(feedID, counterCallback, dcpClientOpts, gocbv2Bucket) + dcpClient, err = NewDCPClient(ctx, feedID, counterCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan, startErr := dcpClient.Start() @@ -534,7 +537,7 @@ func TestResumeStoppedFeed(t *testing.T) { CheckpointPrefix: DefaultMetadataKeys.DCPCheckpointPrefix(t.Name()), } - dcpClient2, err := NewDCPClient(feedID, secondCallback, dcpClientOpts, gocbv2Bucket) + dcpClient2, err := NewDCPClient(ctx, feedID, secondCallback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan2, startErr2 := dcpClient2.Start() @@ -577,7 +580,7 @@ func TestBadAgentPriority(t *testing.T) { gocbv2Bucket, err := AsGocbV2Bucket(bucket.Bucket) require.NoError(t, err) - dcpClient, err := NewDCPClient(feedID, panicCallback, dcpClientOpts, gocbv2Bucket) + dcpClient, err := NewDCPClient(TestCtx(t), feedID, panicCallback, dcpClientOpts, gocbv2Bucket) require.Error(t, err) require.Nil(t, dcpClient) } @@ -609,10 +612,11 @@ func TestDCPOutOfRangeSequence(t *testing.T) { // timeout for feed to complete timeout := time.After(20 * time.Second) + ctx := TestCtx(t) gocbv2Bucket, err := AsGocbV2Bucket(bucket) require.NoError(t, err) - dcpClient, err := NewDCPClient(feedID, callback, dcpClientOpts, gocbv2Bucket) + dcpClient, err := NewDCPClient(ctx, feedID, callback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) doneChan, startErr := dcpClient.Start() @@ -639,7 +643,7 @@ func TestDCPOutOfRangeSequence(t *testing.T) { InitialMetadata: metadata, } - dcpClient, err = NewDCPClient(feedID, callback, dcpClientOpts, gocbv2Bucket) + dcpClient, err = NewDCPClient(ctx, feedID, callback, dcpClientOpts, gocbv2Bucket) require.NoError(t, err) _, startErr = dcpClient.Start() diff --git a/base/dcp_client_worker.go b/base/dcp_client_worker.go index 7f47a49309..dd1fb1f5cb 100644 --- a/base/dcp_client_worker.go +++ b/base/dcp_client_worker.go @@ -84,24 +84,24 @@ func NewDCPWorker(workerID int, metadata DCPMetadataStore, mutationCallback sgbu } // Send accepts incoming events from the DCP client and adds to the worker's buffered feed, to be processed by the main worker goroutine -func (w *DCPWorker) Send(event streamEvent) { +func (w *DCPWorker) Send(ctx context.Context, event streamEvent) { // Ignore mutations if they come in after the client has started closing (CBG-2173) // This needs to be a separate select because if w.eventFeed has capacity at the same time as the terminator is closed, // the outcome is non-deterministic (https://go.dev/ref/spec#Select_statements) select { case <-w.terminator: - TracefCtx(context.TODO(), KeyDCP, "Ignoring stream event (vb:%d) as the client is closing", event.VbID()) + TracefCtx(ctx, KeyDCP, "Ignoring stream event (vb:%d) as the client is closing", event.VbID()) return default: } select { case w.eventFeed <- event: case <-w.terminator: - InfofCtx(context.TODO(), KeyDCP, "Closing DCP worker, DCP Client was closed") + InfofCtx(ctx, KeyDCP, "Closing DCP worker, DCP Client was closed") } } -func (w *DCPWorker) Start(wg *sync.WaitGroup) { +func (w *DCPWorker) Start(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) go func() { defer wg.Done() @@ -118,14 +118,14 @@ func (w *DCPWorker) Start(wg *sync.WaitGroup) { if w.mutationCallback != nil { w.mutationCallback(e.asFeedEvent()) } - w.updateSeq(e.key, vbID, e.seq) + w.updateSeq(ctx, e.key, vbID, e.seq) case deletionEvent: if w.mutationCallback != nil && !w.ignoreDeletes { w.mutationCallback(e.asFeedEvent()) } - w.updateSeq(e.key, vbID, e.seq) + w.updateSeq(ctx, e.key, vbID, e.seq) case seqnoAdvancedEvent: - w.updateSeq(nil, vbID, e.seq) + w.updateSeq(ctx, nil, vbID, e.seq) case endStreamEvent: w.endStreamCallback(e) } @@ -144,7 +144,7 @@ func (w *DCPWorker) checkPendingSnapshot(vbID uint16) { } } -func (w *DCPWorker) updateSeq(key []byte, vbID uint16, seq uint64) { +func (w *DCPWorker) updateSeq(ctx context.Context, key []byte, vbID uint16, seq uint64) { // Ignore DCP checkpoint documents if bytes.HasPrefix(key, w.checkpointPrefixBytes) { return @@ -155,7 +155,7 @@ func (w *DCPWorker) updateSeq(key []byte, vbID uint16, seq uint64) { w.metadata.UpdateSeq(vbID, seq) if time.Since(w.lastMetaPersistTime) > w.metaPersistFrequency { - w.metadata.Persist(w.ID, w.assignedVbs) + w.metadata.Persist(ctx, w.ID, w.assignedVbs) } } diff --git a/base/dcp_common.go b/base/dcp_common.go index 31109b3265..d321baec7d 100644 --- a/base/dcp_common.go +++ b/base/dcp_common.go @@ -252,7 +252,7 @@ func (c *DCPCommon) initMetadata(maxVbNo uint16) { defer c.m.Unlock() // Check for persisted backfill sequences - backfillSeqs, err := c.backfill.loadBackfillSequences(c.metaStore) + backfillSeqs, err := c.backfill.loadBackfillSequences(c.loggingCtx, c.metaStore) if err != nil { // Backfill sequences not present or invalid - will use metadata only backfillSeqs = nil @@ -318,7 +318,7 @@ func (c *DCPCommon) updateSeq(vbucketId uint16, seq uint64, warnOnLowerSeqNo boo // If in backfill, update backfill tracking if c.backfill.isActive() { - c.backfill.updateStats(vbucketId, previousSequence, c.seqs, c.metaStore) + c.backfill.updateStats(c.loggingCtx, vbucketId, previousSequence, c.seqs, c.metaStore) } } @@ -437,12 +437,11 @@ func (b *backfillStatus) snapshotStart(vbNo uint16, snapStart uint64, snapEnd ui b.snapStart[vbNo] = snapStart b.snapEnd[vbNo] = snapEnd } -func (b *backfillStatus) updateStats(vbno uint16, previousVbSequence uint64, currentSequences []uint64, datastore DataStore) { +func (b *backfillStatus) updateStats(ctx context.Context, vbno uint16, previousVbSequence uint64, currentSequences []uint64, datastore DataStore) { if !b.vbActive[vbno] { return } - logCtx := context.TODO() currentVbSequence := currentSequences[vbno] // Update backfill progress. If this vbucket has run past the end of the backfill, only include up to @@ -465,28 +464,28 @@ func (b *backfillStatus) updateStats(vbno uint16, previousVbSequence uint64, cur b.lastPersistTime = time.Now() err := b.persistBackfillSequences(datastore, currentSequences) if err != nil { - WarnfCtx(logCtx, "Error persisting back-fill sequences: %v", err) + WarnfCtx(ctx, "Error persisting back-fill sequences: %v", err) } - b.logBackfillProgress() + b.logBackfillProgress(ctx) } // If backfill is complete, log and do backfill inactivation/cleanup if b.receivedSequences >= b.expectedSequences { - InfofCtx(logCtx, KeyDCP, "Backfill complete") + InfofCtx(ctx, KeyDCP, "Backfill complete") b.active = false err := b.purgeBackfillSequences(datastore) if err != nil { - WarnfCtx(logCtx, "Error purging back-fill sequences: %v", err) + WarnfCtx(ctx, "Error purging back-fill sequences: %v", err) } } } // Logs current backfill progress. Expects caller to have the lock on r.m -func (b *backfillStatus) logBackfillProgress() { +func (b *backfillStatus) logBackfillProgress(ctx context.Context) { if !b.active { return } - InfofCtx(context.TODO(), KeyDCP, "Backfill in progress: %d%% (%d / %d)", int(b.receivedSequences*100/b.expectedSequences), b.receivedSequences, b.expectedSequences) + InfofCtx(ctx, KeyDCP, "Backfill in progress: %d%% (%d / %d)", int(b.receivedSequences*100/b.expectedSequences), b.receivedSequences, b.expectedSequences) } // BackfillSequences defines the format used to persist snapshot information to the _sync:dcp_backfill document @@ -506,13 +505,13 @@ func (b *backfillStatus) persistBackfillSequences(datastore DataStore, currentSe return datastore.Set(b.metaKeys.DCPBackfillKey(), 0, nil, backfillSeqs) } -func (b *backfillStatus) loadBackfillSequences(datastore DataStore) (*BackfillSequences, error) { +func (b *backfillStatus) loadBackfillSequences(ctx context.Context, datastore DataStore) (*BackfillSequences, error) { var backfillSeqs BackfillSequences _, err := datastore.Get(b.metaKeys.DCPBackfillKey(), &backfillSeqs) if err != nil { return nil, err } - InfofCtx(context.TODO(), KeyDCP, "Previously persisted backfill sequences found - will resume") + InfofCtx(ctx, KeyDCP, "Previously persisted backfill sequences found - will resume") return &backfillSeqs, nil } @@ -635,11 +634,11 @@ const ( ) // getNetworkTypeFromConnSpec returns the configured network type, or clusterNetworkAuto if nothing is defined. -func getNetworkTypeFromConnSpec(spec gocbconnstr.ConnSpec) clusterNetworkType { +func getNetworkTypeFromConnSpec(ctx context.Context, spec gocbconnstr.ConnSpec) clusterNetworkType { networkType := clusterNetworkAuto if networkOpt, ok := spec.Options["network"]; ok && len(networkOpt) > 0 { if len(networkOpt) > 1 { - WarnfCtx(context.TODO(), "multiple 'network' options found in connection string - using first one: %q", networkOpt[0]) + WarnfCtx(ctx, "multiple 'network' options found in connection string - using first one: %q", networkOpt[0]) } networkType = clusterNetworkType(networkOpt[0]) } diff --git a/base/dcp_dest.go b/base/dcp_dest.go index 2d15e25d81..fbe9ed1429 100644 --- a/base/dcp_dest.go +++ b/base/dcp_dest.go @@ -112,7 +112,7 @@ func (d *DCPDest) DataUpdate(partition string, key []byte, seq uint64, if !dcpKeyFilter(key, d.metaKeys) { return nil } - event := makeFeedEventForDest(key, val, cas, partitionToVbNo(partition), collectionIDFromExtras(extras), 0, 0, sgbucket.FeedOpMutation) + event := makeFeedEventForDest(key, val, cas, partitionToVbNo(d.loggingCtx, partition), collectionIDFromExtras(extras), 0, 0, sgbucket.FeedOpMutation) d.dataUpdate(seq, event) return nil } @@ -136,7 +136,7 @@ func (d *DCPDest) DataUpdateEx(partition string, key []byte, seq uint64, val []b if !ok { return errors.New("Unable to cast extras of type DEST_EXTRAS_TYPE_GOCB_DCP to cbgt.GocbExtras") } - event = makeFeedEventForDest(key, val, cas, partitionToVbNo(partition), dcpExtras.CollectionId, dcpExtras.Expiry, dcpExtras.Datatype, sgbucket.FeedOpMutation) + event = makeFeedEventForDest(key, val, cas, partitionToVbNo(d.loggingCtx, partition), dcpExtras.CollectionId, dcpExtras.Expiry, dcpExtras.Datatype, sgbucket.FeedOpMutation) } @@ -151,7 +151,7 @@ func (d *DCPDest) DataDelete(partition string, key []byte, seq uint64, return nil } - event := makeFeedEventForDest(key, nil, cas, partitionToVbNo(partition), collectionIDFromExtras(extras), 0, 0, sgbucket.FeedOpDeletion) + event := makeFeedEventForDest(key, nil, cas, partitionToVbNo(d.loggingCtx, partition), collectionIDFromExtras(extras), 0, 0, sgbucket.FeedOpDeletion) d.dataUpdate(seq, event) return nil } @@ -174,7 +174,7 @@ func (d *DCPDest) DataDeleteEx(partition string, key []byte, seq uint64, if !ok { return errors.New("Unable to cast extras of type DEST_EXTRAS_TYPE_GOCB_DCP to cbgt.GocbExtras") } - event = makeFeedEventForDest(key, dcpExtras.Value, cas, partitionToVbNo(partition), dcpExtras.CollectionId, dcpExtras.Expiry, dcpExtras.Datatype, sgbucket.FeedOpDeletion) + event = makeFeedEventForDest(key, dcpExtras.Value, cas, partitionToVbNo(d.loggingCtx, partition), dcpExtras.CollectionId, dcpExtras.Expiry, dcpExtras.Datatype, sgbucket.FeedOpDeletion) } d.dataUpdate(seq, event) @@ -183,12 +183,12 @@ func (d *DCPDest) DataDeleteEx(partition string, key []byte, seq uint64, func (d *DCPDest) SnapshotStart(partition string, snapStart, snapEnd uint64) error { - d.snapshotStart(partitionToVbNo(partition), snapStart, snapEnd) + d.snapshotStart(partitionToVbNo(d.loggingCtx, partition), snapStart, snapEnd) return nil } func (d *DCPDest) OpaqueGet(partition string) (value []byte, lastSeq uint64, err error) { - vbNo := partitionToVbNo(partition) + vbNo := partitionToVbNo(d.loggingCtx, partition) if !d.metaInitComplete[vbNo] { d.InitVbMeta(vbNo) d.metaInitComplete[vbNo] = true @@ -202,7 +202,7 @@ func (d *DCPDest) OpaqueGet(partition string) (value []byte, lastSeq uint64, err } func (d *DCPDest) OpaqueSet(partition string, value []byte) error { - vbNo := partitionToVbNo(partition) + vbNo := partitionToVbNo(d.loggingCtx, partition) if !d.metaInitComplete[vbNo] { d.InitVbMeta(vbNo) d.metaInitComplete[vbNo] = true @@ -212,12 +212,12 @@ func (d *DCPDest) OpaqueSet(partition string, value []byte) error { } func (d *DCPDest) Rollback(partition string, rollbackSeq uint64) error { - return d.rollback(partitionToVbNo(partition), rollbackSeq) + return d.rollback(partitionToVbNo(d.loggingCtx, partition), rollbackSeq) } func (d *DCPDest) RollbackEx(partition string, vbucketUUID uint64, rollbackSeq uint64) error { cbgtMeta := makeVbucketMetadataForSequence(vbucketUUID, rollbackSeq) - return d.rollbackEx(partitionToVbNo(partition), vbucketUUID, rollbackSeq, cbgtMeta) + return d.rollbackEx(partitionToVbNo(d.loggingCtx, partition), vbucketUUID, rollbackSeq, cbgtMeta) } // TODO: Not implemented, review potential usage @@ -243,10 +243,10 @@ func (d *DCPDest) Stats(io.Writer) error { return nil } -func partitionToVbNo(partition string) uint16 { +func partitionToVbNo(ctx context.Context, partition string) uint16 { vbNo, err := strconv.Atoi(partition) if err != nil { - ErrorfCtx(context.Background(), "Unexpected non-numeric partition value %s, ignoring: %v", partition, err) + ErrorfCtx(ctx, "Unexpected non-numeric partition value %s, ignoring: %v", partition, err) return 0 } return uint16(vbNo) diff --git a/base/dcp_receiver.go b/base/dcp_receiver.go index 37746a4528..3f2db0a71e 100644 --- a/base/dcp_receiver.go +++ b/base/dcp_receiver.go @@ -10,16 +10,11 @@ package base import ( "context" - "crypto/tls" - "errors" "expvar" - "github.com/couchbase/go-couchbase" "github.com/couchbase/go-couchbase/cbdatasource" "github.com/couchbase/gomemcached" sgbucket "github.com/couchbase/sg-bucket" - pkgerrors "github.com/pkg/errors" - "gopkg.in/couchbaselabs/gocbconnstr.v1" ) // Memcached binary protocol datatype bit flags (https://github.com/couchbase/memcached/blob/master/docs/BinaryProtocol.md#data-types), @@ -40,10 +35,10 @@ type DCPReceiver struct { *DCPCommon } -func NewDCPReceiver(callback sgbucket.FeedEventCallbackFunc, bucket Bucket, maxVbNo uint16, persistCheckpoints bool, dbStats *expvar.Map, feedID string, checkpointPrefix string, metaKeys *MetadataKeys) (cbdatasource.Receiver, context.Context, error) { +func NewDCPReceiver(ctx context.Context, callback sgbucket.FeedEventCallbackFunc, bucket Bucket, maxVbNo uint16, persistCheckpoints bool, dbStats *expvar.Map, feedID string, checkpointPrefix string, metaKeys *MetadataKeys) (cbdatasource.Receiver, context.Context, error) { metadataStore := bucket.DefaultDataStore() - dcpCommon, err := NewDCPCommon(context.TODO(), callback, bucket, metadataStore, maxVbNo, persistCheckpoints, dbStats, feedID, checkpointPrefix, metaKeys) + dcpCommon, err := NewDCPCommon(ctx, callback, bucket, metadataStore, maxVbNo, persistCheckpoints, dbStats, feedID, checkpointPrefix, metaKeys) if err != nil { return nil, nil, err } @@ -210,168 +205,3 @@ func (nph NoPasswordAuthHandler) GetCredentials() (username string, password str return "", "", bucketname } - -// This starts a cbdatasource powered DCP Feed using an entirely separate connection to Couchbase Server than anything the existing -// bucket is using, and it uses the go-couchbase cbdatasource DCP abstraction layer -func StartDCPFeed(bucket Bucket, spec BucketSpec, args sgbucket.FeedArguments, callback sgbucket.FeedEventCallbackFunc, dbStats *expvar.Map, metaKeys *MetadataKeys) error { - - connSpec, err := gocbconnstr.Parse(spec.Server) - if err != nil { - return err - } - - // Recommended usage of cbdatasource is to let it manage it's own dedicated connection, so we're not - // reusing the bucket connection we've already established. - urls, errConvertServerSpec := CouchbaseURIToHttpURL(bucket, spec.Server, &connSpec) - - if errConvertServerSpec != nil { - return errConvertServerSpec - } - - poolName := DefaultPool - bucketName := spec.BucketName - - vbucketIdsArr := []uint16(nil) // nil means get all the vbuckets. - - maxVbno, err := bucket.GetMaxVbno() - if err != nil { - return err - } - - persistCheckpoints := false - if args.Backfill == sgbucket.FeedResume { - persistCheckpoints = true - } - - feedID := args.ID - if feedID == "" { - InfofCtx(context.TODO(), KeyDCP, "DCP feed started without feedID specified - defaulting to %s", DCPCachingFeedID) - feedID = DCPCachingFeedID - } - receiver, loggingCtx, err := NewDCPReceiver(callback, bucket, maxVbno, persistCheckpoints, dbStats, feedID, args.CheckpointPrefix, metaKeys) - if err != nil { - return err - } - - var dcpReceiver *DCPReceiver - switch v := receiver.(type) { - case *DCPReceiver: - dcpReceiver = v - case *DCPLoggingReceiver: - dcpReceiver = v.rec - default: - return errors.New("NewDCPReceiver returned unexpected receiver implementation") - } - - // Initialize the feed based on the backfill type - _, feedInitErr := dcpReceiver.initFeed(args.Backfill) - if feedInitErr != nil { - return feedInitErr - } - - dataSourceOptions := CopyDefaultBucketDatasourceOptions() - if spec.UseXattrs { - dataSourceOptions.IncludeXAttrs = true - } - - dataSourceOptions.Logf = func(fmt string, v ...interface{}) { - DebugfCtx(loggingCtx, KeyDCP, fmt, v...) - } - - dataSourceOptions.Name, err = GenerateDcpStreamName(feedID) - InfofCtx(loggingCtx, KeyDCP, "DCP feed starting with name %s", dataSourceOptions.Name) - if err != nil { - return pkgerrors.Wrap(err, "unable to generate DCP stream name") - } - - auth := spec.Auth - - // If using client certificate for authentication, configure go-couchbase for cbdatasource's initial - // connection to retrieve cluster configuration. go-couchbase doesn't support handling - // x509 auth and root ca verification as separate concerns. - if spec.Certpath != "" && spec.Keypath != "" { - couchbase.SetCertFile(spec.Certpath) - couchbase.SetKeyFile(spec.Keypath) - auth = NoPasswordAuthHandler{Handler: spec.Auth} - couchbase.SetRootFile(spec.CACertPath) - couchbase.SetSkipVerify(false) - } - - if spec.IsTLS() { - dataSourceOptions.TLSConfig = func() *tls.Config { - return spec.TLSConfig() - } - } - - networkType := getNetworkTypeFromConnSpec(connSpec) - InfofCtx(loggingCtx, KeyDCP, "Using network type: %s", networkType) - - // default (aka internal) networking is handled by cbdatasource, so we can avoid the shims altogether in this case, for all other cases we need shims to remap hosts. - if networkType != clusterNetworkDefault { - // A lookup of host dest to external alternate address hostnames - dataSourceOptions.ConnectBucket, dataSourceOptions.Connect, dataSourceOptions.ConnectTLS = alternateAddressShims(loggingCtx, spec.IsTLS(), connSpec.Addresses, networkType) - } - - DebugfCtx(loggingCtx, KeyDCP, "Connecting to new bucket datasource. URLs:%s, pool:%s, bucket:%s", MD(urls), MD(poolName), MD(bucketName)) - - bds, err := cbdatasource.NewBucketDataSource( - urls, - poolName, - bucketName, - "", - vbucketIdsArr, - auth, - dcpReceiver, - dataSourceOptions, - ) - - if err != nil { - return pkgerrors.WithStack(RedactErrorf("Error connecting to new bucket cbdatasource. FeedID:%s URLs:%s, pool:%s, bucket:%s. Error: %v", feedID, MD(urls), MD(poolName), MD(bucketName), err)) - } - - if err = bds.Start(); err != nil { - return pkgerrors.WithStack(RedactErrorf("Error starting bucket cbdatasource. FeedID:%s URLs:%s, pool:%s, bucket:%s. Error: %v", feedID, MD(urls), MD(poolName), MD(bucketName), err)) - } - - // Close the data source if feed terminator is closed - if args.Terminator != nil { - go func() { - <-args.Terminator - TracefCtx(loggingCtx, KeyDCP, "Closing DCP Feed [%s-%s] based on termination notification", MD(bucketName), feedID) - if err := bds.Close(); err != nil { - DebugfCtx(loggingCtx, KeyDCP, "Error closing DCP Feed [%s-%s] based on termination notification, Error: %v", MD(bucketName), feedID, err) - } - if args.DoneChan != nil { - close(args.DoneChan) - } - }() - } - - return nil - -} - -// CopyDefaultBucketDatasourceOptions makes a copy of cbdatasource.DefaultBucketDataSourceOptions. -// DeepCopyInefficient can't be used here due to function definitions present on BucketDataSourceOptions (ConnectBucket, etc) -func CopyDefaultBucketDatasourceOptions() *cbdatasource.BucketDataSourceOptions { - return &cbdatasource.BucketDataSourceOptions{ - ClusterManagerBackoffFactor: cbdatasource.DefaultBucketDataSourceOptions.ClusterManagerBackoffFactor, - ClusterManagerSleepInitMS: cbdatasource.DefaultBucketDataSourceOptions.ClusterManagerSleepInitMS, - ClusterManagerSleepMaxMS: cbdatasource.DefaultBucketDataSourceOptions.ClusterManagerSleepMaxMS, - - DataManagerBackoffFactor: cbdatasource.DefaultBucketDataSourceOptions.DataManagerBackoffFactor, - DataManagerSleepInitMS: cbdatasource.DefaultBucketDataSourceOptions.DataManagerSleepInitMS, - DataManagerSleepMaxMS: cbdatasource.DefaultBucketDataSourceOptions.DataManagerSleepMaxMS, - - FeedBufferSizeBytes: cbdatasource.DefaultBucketDataSourceOptions.FeedBufferSizeBytes, - FeedBufferAckThreshold: cbdatasource.DefaultBucketDataSourceOptions.FeedBufferAckThreshold, - - NoopTimeIntervalSecs: cbdatasource.DefaultBucketDataSourceOptions.NoopTimeIntervalSecs, - - TraceCapacity: cbdatasource.DefaultBucketDataSourceOptions.TraceCapacity, - - PingTimeoutMS: cbdatasource.DefaultBucketDataSourceOptions.PingTimeoutMS, - - IncludeXAttrs: cbdatasource.DefaultBucketDataSourceOptions.IncludeXAttrs, - } -} diff --git a/base/dcp_sharded.go b/base/dcp_sharded.go index f0bd8b7559..1e2e3e73fa 100644 --- a/base/dcp_sharded.go +++ b/base/dcp_sharded.go @@ -89,7 +89,7 @@ func StartShardedDCPFeed(ctx context.Context, dbName string, configGroup string, // Register heartbeat listener to trigger removal from cfg when // other SG nodes stop sending heartbeats. - listener, err := registerHeartbeatListener(heartbeater, cbgtContext) + listener, err := registerHeartbeatListener(ctx, heartbeater, cbgtContext) if err != nil { return nil, err } @@ -169,11 +169,11 @@ func createCBGTIndex(ctx context.Context, c *CbgtContext, dbName string, configG cbgt.RegisterBucketDataSourceOptionsCallback(indexName, c.Manager.UUID(), func(options *cbdatasource.BucketDataSourceOptions) *cbdatasource.BucketDataSourceOptions { if spec.IsTLS() { options.TLSConfig = func() *tls.Config { - return spec.TLSConfig() + return spec.TLSConfig(ctx) } } - networkType := getNetworkTypeFromConnSpec(connSpec) + networkType := getNetworkTypeFromConnSpec(ctx, connSpec) InfofCtx(ctx, KeyDCP, "Using network type: %s", networkType) // default (aka internal) networking is handled by cbdatasource, so we can avoid the shims altogether in this case, for all other cases we need shims to remap hosts. @@ -376,7 +376,7 @@ func initCBGTManager(ctx context.Context, bucket Bucket, spec BucketSpec, cfgSG if spec.TLSSkipVerify { setCbgtRootCertsForBucket(bucketUUID, nil) } else { - certs, err := getRootCAs(spec.CACertPath) + certs, err := getRootCAs(ctx, spec.CACertPath) if err != nil { return nil, fmt.Errorf("failed to load root CAs: %w", err) } @@ -531,14 +531,14 @@ func initCfgCB(bucket Bucket, spec BucketSpec) (*cbgt.CfgCB, error) { return cfgCB, nil } -func registerHeartbeatListener(heartbeater Heartbeater, cbgtContext *CbgtContext) (*importHeartbeatListener, error) { +func registerHeartbeatListener(ctx context.Context, heartbeater Heartbeater, cbgtContext *CbgtContext) (*importHeartbeatListener, error) { if cbgtContext == nil || cbgtContext.Manager == nil || cbgtContext.Cfg == nil || heartbeater == nil { return nil, errors.New("Unable to register import heartbeat listener with nil manager, cfg or heartbeater") } // Register listener for import, uses cfg and manager to manage set of participating nodes - importHeartbeatListener, err := NewImportHeartbeatListener(cbgtContext) + importHeartbeatListener, err := NewImportHeartbeatListener(ctx, cbgtContext) if err != nil { return nil, err } @@ -561,16 +561,16 @@ type importHeartbeatListener struct { lock sync.RWMutex // lock for nodeIDs access } -func NewImportHeartbeatListener(ctx *CbgtContext) (*importHeartbeatListener, error) { +func NewImportHeartbeatListener(ctx context.Context, cbgtCtx *CbgtContext) (*importHeartbeatListener, error) { - if ctx == nil { + if cbgtCtx == nil { return nil, errors.New("ctx must not be nil for ImportHeartbeatListener") } listener := &importHeartbeatListener{ - ctx: ctx, - mgr: ctx.Manager, - cfg: ctx.Cfg, + ctx: cbgtCtx, + mgr: cbgtCtx.Manager, + cfg: cbgtCtx.Cfg, terminator: make(chan struct{}), } @@ -581,7 +581,7 @@ func NewImportHeartbeatListener(ctx *CbgtContext) (*importHeartbeatListener, err } // Subscribe to changes to the known node set key - err = listener.subscribeNodeChanges() + err = listener.subscribeNodeChanges(ctx) if err != nil { return nil, err } @@ -594,29 +594,28 @@ func (l *importHeartbeatListener) Name() string { } // When we detect other nodes have stopped pushing heartbeats, use manager to remove from cfg -func (l *importHeartbeatListener) StaleHeartbeatDetected(nodeUUID string) { +func (l *importHeartbeatListener) StaleHeartbeatDetected(ctx context.Context, nodeUUID string) { - InfofCtx(context.TODO(), KeyCluster, "StaleHeartbeatDetected by import listener for node: %v", nodeUUID) + InfofCtx(ctx, KeyCluster, "StaleHeartbeatDetected by import listener for node: %v", nodeUUID) err := cbgt.UnregisterNodes(l.cfg, l.mgr.Version(), []string{nodeUUID}) if err != nil { - WarnfCtx(context.TODO(), "Attempt to unregister %v from CBGT got error: %v", nodeUUID, err) + WarnfCtx(ctx, "Attempt to unregister %v from CBGT got error: %v", nodeUUID, err) } } // subscribeNodeChanges registers with the manager's cfg implementation for notifications on changes to the // NODE_DEFS_KNOWN key. When notified, refreshes the handlers nodeIDs. -func (l *importHeartbeatListener) subscribeNodeChanges() error { - logCtx := context.TODO() +func (l *importHeartbeatListener) subscribeNodeChanges(ctx context.Context) error { cfgEvents := make(chan cbgt.CfgEvent) err := l.cfg.Subscribe(cbgt.CfgNodeDefsKey(cbgt.NODE_DEFS_KNOWN), cfgEvents) if err != nil { - DebugfCtx(logCtx, KeyCluster, "Error subscribing NODE_DEFS_KNOWN changes: %v", err) + DebugfCtx(ctx, KeyCluster, "Error subscribing NODE_DEFS_KNOWN changes: %v", err) return err } err = l.cfg.Subscribe(cbgt.CfgNodeDefsKey(cbgt.NODE_DEFS_WANTED), cfgEvents) if err != nil { - DebugfCtx(logCtx, KeyCluster, "Error subscribing NODE_DEFS_WANTED changes: %v", err) + DebugfCtx(ctx, KeyCluster, "Error subscribing NODE_DEFS_WANTED changes: %v", err) return err } go func() { @@ -626,12 +625,12 @@ func (l *importHeartbeatListener) subscribeNodeChanges() error { case <-cfgEvents: localNodeRegistered, err := l.reloadNodes() if err != nil { - WarnfCtx(logCtx, "Error while reloading heartbeat node definitions: %v", err) + WarnfCtx(ctx, "Error while reloading heartbeat node definitions: %v", err) } if !localNodeRegistered { registerErr := l.mgr.Register(cbgt.NODE_DEFS_WANTED) if registerErr != nil { - WarnfCtx(logCtx, "Error attempting to re-register node, node will not participate in import until restarted or cbgt cfg is next updated: %v", registerErr) + WarnfCtx(ctx, "Error attempting to re-register node, node will not participate in import until restarted or cbgt cfg is next updated: %v", registerErr) } } diff --git a/base/dcp_test.go b/base/dcp_test.go index d07a4de0a1..f75dddfdf1 100644 --- a/base/dcp_test.go +++ b/base/dcp_test.go @@ -427,7 +427,7 @@ func TestConcurrentCBGTIndexCreation(t *testing.T) { testDBName := "testDB" // Use an bucket-backed cfg - cfg, err := NewCfgSG(dataStore, "") + cfg, err := NewCfgSG(TestCtx(t), dataStore, "") require.NoError(t, err) // Define index type for db name diff --git a/base/gocb_dcp_feed.go b/base/gocb_dcp_feed.go index 734fa75c9c..ac38aed4e8 100644 --- a/base/gocb_dcp_feed.go +++ b/base/gocb_dcp_feed.go @@ -98,6 +98,7 @@ func StartGocbDCPFeed(ctx context.Context, bucket *GocbV2Bucket, bucketName stri } dcpClient, err := NewDCPClient( + ctx, feedName, callback, options, diff --git a/base/gocb_utils.go b/base/gocb_utils.go index e0bc0c8a3b..dd1b77604a 100644 --- a/base/gocb_utils.go +++ b/base/gocb_utils.go @@ -21,10 +21,10 @@ import ( ) // GoCBv2SecurityConfig returns a gocb.SecurityConfig to use when connecting given a CA Cert path. -func GoCBv2SecurityConfig(tlsSkipVerify *bool, caCertPath string) (sc gocb.SecurityConfig, err error) { +func GoCBv2SecurityConfig(ctx context.Context, tlsSkipVerify *bool, caCertPath string) (sc gocb.SecurityConfig, err error) { var certPool *x509.CertPool = nil if tlsSkipVerify == nil || !*tlsSkipVerify { // Add certs if ServerTLSSkipVerify is not set - certPool, err = getRootCAs(caCertPath) + certPool, err = getRootCAs(ctx, caCertPath) if err != nil { return sc, err } @@ -121,10 +121,10 @@ func GoCBCoreAuthConfig(username, password, certPath, keyPath string) (gocbcore. }, nil } -func GoCBCoreTLSRootCAProvider(tlsSkipVerify *bool, caCertPath string) (wrapper func() *x509.CertPool, err error) { +func GoCBCoreTLSRootCAProvider(ctx context.Context, tlsSkipVerify *bool, caCertPath string) (wrapper func() *x509.CertPool, err error) { var certPool *x509.CertPool = nil if tlsSkipVerify == nil || !*tlsSkipVerify { // Add certs if ServerTLSSkipVerify is not set - certPool, err = getRootCAs(caCertPath) + certPool, err = getRootCAs(ctx, caCertPath) if err != nil { return nil, err } @@ -137,7 +137,7 @@ func GoCBCoreTLSRootCAProvider(tlsSkipVerify *bool, caCertPath string) (wrapper // getRootCAs gets generates a cert pool from the certs at caCertPath. If caCertPath is empty, the systems cert pool is used. // If an error happens when retrieving the system cert pool, it is logged (not returned) and an empty (not nil) cert pool is returned. -func getRootCAs(caCertPath string) (*x509.CertPool, error) { +func getRootCAs(ctx context.Context, caCertPath string) (*x509.CertPool, error) { if caCertPath != "" { rootCAs := x509.NewCertPool() @@ -157,7 +157,7 @@ func getRootCAs(caCertPath string) (*x509.CertPool, error) { rootCAs, err := x509.SystemCertPool() if err != nil { rootCAs = x509.NewCertPool() - WarnfCtx(context.Background(), "Could not retrieve root CAs: %v", err) + WarnfCtx(ctx, "Could not retrieve root CAs: %v", err) } return rootCAs, nil } diff --git a/base/gocb_utils_test.go b/base/gocb_utils_test.go index 20ade6c371..c0e29bdcfa 100644 --- a/base/gocb_utils_test.go +++ b/base/gocb_utils_test.go @@ -66,7 +66,7 @@ func TestGoCBv2SecurityConfig(t *testing.T) { // for _, test := range tests { t.Run(test.name, func(t *testing.T) { - sc, err := GoCBv2SecurityConfig(test.tlsSkipVerify, test.caCertPath) + sc, err := GoCBv2SecurityConfig(TestCtx(t), test.tlsSkipVerify, test.caCertPath) if test.expectError { assert.Error(t, err) assert.Nil(t, sc.TLSRootCAs) diff --git a/base/heartbeat.go b/base/heartbeat.go index 3132abe3f7..ce74fc1dfe 100644 --- a/base/heartbeat.go +++ b/base/heartbeat.go @@ -30,10 +30,10 @@ const ( type Heartbeater interface { RegisterListener(listener HeartbeatListener) error UnregisterListener(name string) - Start() error - StartSendingHeartbeats() error - StartCheckingHeartbeats() error - Stop() + Start(context.Context) error + StartSendingHeartbeats(context.Context) error + StartCheckingHeartbeats(context.Context) error + Stop(context.Context) } // A HeartbeatListener defines the set of nodes it wants to monitor, and a callback when one of those nodes stops @@ -41,7 +41,7 @@ type Heartbeater interface { type HeartbeatListener interface { Name() string GetNodes() (nodeUUIDs []string, err error) - StaleHeartbeatDetected(nodeUUID string) + StaleHeartbeatDetected(ctx context.Context, nodeUUID string) Stop() } @@ -103,17 +103,17 @@ func NewCouchbaseHeartbeater(dataStore DataStore, keyPrefix, nodeUUID string) (h // Start the heartbeater. Underlying methods performs the first heartbeat send and check synchronously, then // starts scheduled goroutines for ongoing processing. -func (h *couchbaseHeartBeater) Start() error { +func (h *couchbaseHeartBeater) Start(ctx context.Context) error { - if err := h.StartSendingHeartbeats(); err != nil { + if err := h.StartSendingHeartbeats(ctx); err != nil { return err } - if err := h.StartCheckingHeartbeats(); err != nil { + if err := h.StartCheckingHeartbeats(ctx); err != nil { return err } - DebugfCtx(context.TODO(), KeyCluster, "Sending node heartbeats at interval: %v", h.heartbeatSendInterval) + DebugfCtx(ctx, KeyCluster, "Sending node heartbeats at interval: %v", h.heartbeatSendInterval) return nil @@ -121,7 +121,7 @@ func (h *couchbaseHeartBeater) Start() error { // Stop terminates the send and check goroutines, and blocks for up to 1s // until goroutines are actually terminated. -func (h *couchbaseHeartBeater) Stop() { +func (h *couchbaseHeartBeater) Stop(ctx context.Context) { if h == nil { return @@ -134,7 +134,7 @@ func (h *couchbaseHeartBeater) Stop() { for h.sendActive.IsTrue() || h.checkActive.IsTrue() { waitTimeMs += 10 if waitTimeMs > maxWaitTimeMs { - WarnfCtx(context.Background(), "couchbaseHeartBeater didn't complete Stop() within expected elapsed time") + WarnfCtx(ctx, "couchbaseHeartBeater didn't complete Stop() within expected elapsed time") return } time.Sleep(10 * time.Millisecond) @@ -143,7 +143,7 @@ func (h *couchbaseHeartBeater) Stop() { } // Send initial heartbeat, and start goroutine to schedule sendHeartbeat invocation -func (h *couchbaseHeartBeater) StartSendingHeartbeats() error { +func (h *couchbaseHeartBeater) StartSendingHeartbeats(ctx context.Context) error { if err := h.sendHeartbeat(); err != nil { return err } @@ -161,7 +161,7 @@ func (h *couchbaseHeartBeater) StartSendingHeartbeats() error { return case <-ticker.C: if err := h.sendHeartbeat(); err != nil { - WarnfCtx(context.Background(), "Unexpected error sending heartbeat - will be retried: %v", err) + WarnfCtx(ctx, "Unexpected error sending heartbeat - will be retried: %v", err) } } } @@ -171,10 +171,10 @@ func (h *couchbaseHeartBeater) StartSendingHeartbeats() error { } // Perform initial heartbeat check, then start goroutine to schedule check for stale heartbeats -func (h *couchbaseHeartBeater) StartCheckingHeartbeats() error { +func (h *couchbaseHeartBeater) StartCheckingHeartbeats(ctx context.Context) error { - if err := h.checkStaleHeartbeats(); err != nil { - WarnfCtx(context.Background(), "Error checking for stale heartbeats: %v", err) + if err := h.checkStaleHeartbeats(ctx); err != nil { + WarnfCtx(ctx, "Error checking for stale heartbeats: %v", err) } ticker := time.NewTicker(h.heartbeatPollInterval) @@ -187,8 +187,8 @@ func (h *couchbaseHeartBeater) StartCheckingHeartbeats() error { ticker.Stop() return case <-ticker.C: - if err := h.checkStaleHeartbeats(); err != nil { - WarnfCtx(context.Background(), "Error checking for stale heartbeats: %v", err) + if err := h.checkStaleHeartbeats(ctx); err != nil { + WarnfCtx(ctx, "Error checking for stale heartbeats: %v", err) } } } @@ -246,13 +246,13 @@ func (l ListenerMap) String() string { // getAllNodes returns all nodes from all registered listeners as a map from nodeUUID to the listeners // registered for that node -func (h *couchbaseHeartBeater) getNodeListenerMap() ListenerMap { +func (h *couchbaseHeartBeater) getNodeListenerMap(ctx context.Context) ListenerMap { nodeToListenerMap := make(ListenerMap) h.heartbeatListenersMutex.RLock() for _, listener := range h.heartbeatListeners { listenerNodes, err := listener.GetNodes() if err != nil { - WarnfCtx(context.Background(), "Error obtaining node set for listener %s - will be omitted for this heartbeat iteration. Error: %v", listener.Name(), err) + WarnfCtx(ctx, "Error obtaining node set for listener %s - will be omitted for this heartbeat iteration. Error: %v", listener.Name(), err) } for _, nodeUUID := range listenerNodes { _, ok := nodeToListenerMap[nodeUUID] @@ -266,11 +266,11 @@ func (h *couchbaseHeartBeater) getNodeListenerMap() ListenerMap { return nodeToListenerMap } -func (h *couchbaseHeartBeater) checkStaleHeartbeats() error { +func (h *couchbaseHeartBeater) checkStaleHeartbeats(ctx context.Context) error { // Build set of all nodes - nodeListenerMap := h.getNodeListenerMap() - TracefCtx(context.Background(), KeyCluster, "Checking heartbeats for node set: %s", nodeListenerMap) + nodeListenerMap := h.getNodeListenerMap(ctx) + TracefCtx(ctx, KeyCluster, "Checking heartbeats for node set: %s", nodeListenerMap) for heartbeatNodeUUID, listeners := range nodeListenerMap { if heartbeatNodeUUID == h.nodeUUID { @@ -292,7 +292,7 @@ func (h *couchbaseHeartBeater) checkStaleHeartbeats() error { // doc not found, which means the heartbeat doc expired. // Notify listeners for this node for _, listener := range listeners { - listener.StaleHeartbeatDetected(heartbeatNodeUUID) + listener.StaleHeartbeatDetected(ctx, heartbeatNodeUUID) } } } @@ -390,8 +390,8 @@ func (dh *documentBackedListener) Stop() { return } -func (dh *documentBackedListener) StaleHeartbeatDetected(nodeUUID string) { - _ = dh.RemoveNode(nodeUUID) +func (dh *documentBackedListener) StaleHeartbeatDetected(ctx context.Context, nodeUUID string) { + _ = dh.RemoveNode(ctx, nodeUUID) atomic.AddUint64(&dh.staleNotificationCount, 1) } @@ -400,17 +400,17 @@ func (dh *documentBackedListener) StaleNotificationCount() uint64 { } // Adds the node to the tracking document -func (dh *documentBackedListener) AddNode(nodeID string) error { - return dh.updateNodeList(nodeID, false) +func (dh *documentBackedListener) AddNode(ctx context.Context, nodeID string) error { + return dh.updateNodeList(ctx, nodeID, false) } // Removes the node to the tracking document -func (dh *documentBackedListener) RemoveNode(nodeID string) error { - return dh.updateNodeList(nodeID, true) +func (dh *documentBackedListener) RemoveNode(ctx context.Context, nodeID string) error { + return dh.updateNodeList(ctx, nodeID, true) } // Adds or removes a nodeID from the node list document -func (dh *documentBackedListener) updateNodeList(nodeID string, remove bool) error { +func (dh *documentBackedListener) updateNodeList(ctx context.Context, nodeID string, remove bool) error { dh.lock.Lock() defer dh.lock.Unlock() @@ -445,7 +445,7 @@ func (dh *documentBackedListener) updateNodeList(nodeID string, remove bool) err dh.nodeIDs = append(dh.nodeIDs, nodeID) } - InfofCtx(context.TODO(), KeyCluster, "Updating nodeList document (%s) with node IDs: %v", dh.nodeListKey, dh.nodeIDs) + InfofCtx(ctx, KeyCluster, "Updating nodeList document (%s) with node IDs: %v", dh.nodeListKey, dh.nodeIDs) casOut, err := dh.datastore.WriteCas(dh.nodeListKey, 0, 0, dh.cas, dh.nodeIDs, 0) diff --git a/base/heartbeat_test.go b/base/heartbeat_test.go index f4a31fb685..4602924434 100644 --- a/base/heartbeat_test.go +++ b/base/heartbeat_test.go @@ -115,6 +115,7 @@ func TestCouchbaseHeartbeaters(t *testing.T) { nodeCount := 3 nodes := make([]*couchbaseHeartBeater, nodeCount) listeners := make([]*documentBackedListener, nodeCount) + ctx := TestCtx(t) for i := 0; i < nodeCount; i++ { nodeUUID := fmt.Sprintf("node%d", i) node, err := NewCouchbaseHeartbeater(dataStore, keyprefix, nodeUUID) @@ -124,13 +125,13 @@ func TestCouchbaseHeartbeaters(t *testing.T) { assert.NoError(t, node.SetExpirySeconds(2)) // Start node - assert.NoError(t, node.Start()) + assert.NoError(t, node.Start(ctx)) // Create and register listener. // Simulates service starting on node, and self-registering the nodeUUID to that listener's node set listener, err := NewDocumentBackedListener(dataStore, keyprefix) require.NoError(t, err) - assert.NoError(t, listener.AddNode(nodeUUID)) + assert.NoError(t, listener.AddNode(ctx, nodeUUID)) assert.NoError(t, node.RegisterListener(listener)) nodes[i] = node @@ -146,7 +147,7 @@ func TestCouchbaseHeartbeaters(t *testing.T) { assert.True(t, nodes[0].sendCount > 0) // Stop node 0 - nodes[0].Stop() + nodes[0].Stop(ctx) // Wait for another node to detect node0 has stopped sending heartbeats retryUntilFunc = func() bool { @@ -170,8 +171,8 @@ func TestCouchbaseHeartbeaters(t *testing.T) { assert.Contains(t, activeNodes, "node2") // Stop heartbeaters - nodes[1].Stop() - nodes[2].Stop() + nodes[1].Stop(ctx) + nodes[2].Stop(ctx) } // TestNewCouchbaseHeartbeater simulates three nodes, with two services (listeners). @@ -196,6 +197,7 @@ func TestCouchbaseHeartbeatersMultipleListeners(t *testing.T) { nodes := make([]*couchbaseHeartBeater, nodeCount) importListeners := make([]*documentBackedListener, nodeCount) sgrListeners := make([]*documentBackedListener, nodeCount) + ctx := TestCtx(t) for i := 0; i < nodeCount; i++ { nodeUUID := fmt.Sprintf("node%d", i) node, err := NewCouchbaseHeartbeater(dataStore, keyprefix, nodeUUID) @@ -205,13 +207,13 @@ func TestCouchbaseHeartbeatersMultipleListeners(t *testing.T) { assert.NoError(t, node.SetExpirySeconds(2)) // Start node - assert.NoError(t, node.Start()) + assert.NoError(t, node.Start(ctx)) // Create and register import listener on all nodes. // Simulates service starting on node, and self-registering the nodeUUID to that listener's node set importListener, err := NewDocumentBackedListener(dataStore, keyprefix+":import") require.NoError(t, err) - assert.NoError(t, importListener.AddNode(nodeUUID)) + assert.NoError(t, importListener.AddNode(ctx, nodeUUID)) assert.NoError(t, node.RegisterListener(importListener)) importListeners[i] = importListener @@ -219,7 +221,7 @@ func TestCouchbaseHeartbeatersMultipleListeners(t *testing.T) { if i < 2 { sgrListener, err := NewDocumentBackedListener(dataStore, keyprefix+":sgr") require.NoError(t, err) - assert.NoError(t, sgrListener.AddNode(nodeUUID)) + assert.NoError(t, sgrListener.AddNode(ctx, nodeUUID)) assert.NoError(t, node.RegisterListener(sgrListener)) sgrListeners[i] = sgrListener } @@ -236,7 +238,7 @@ func TestCouchbaseHeartbeatersMultipleListeners(t *testing.T) { assert.True(t, nodes[0].sendCount > 0) // Stop node 1 - nodes[0].Stop() + nodes[0].Stop(ctx) // Wait for both listener types node to detect node1 has stopped sending heartbeats retryUntilFunc = func() bool { @@ -276,8 +278,8 @@ func TestCouchbaseHeartbeatersMultipleListeners(t *testing.T) { assert.NotContains(t, activeReplicateNodes, "node2") // Stop heartbeaters - nodes[1].Stop() - nodes[2].Stop() + nodes[1].Stop(ctx) + nodes[2].Stop(ctx) } // TestNewCouchbaseHeartbeater simulates three nodes. The minimum time window for failed node @@ -326,9 +328,10 @@ func TestCBGTManagerHeartbeater(t *testing.T) { assert.NoError(t, node2.SetExpirySeconds(2)) assert.NoError(t, node3.SetExpirySeconds(2)) - assert.NoError(t, node1.Start()) - assert.NoError(t, node2.Start()) - assert.NoError(t, node3.Start()) + ctx := TestCtx(t) + assert.NoError(t, node1.Start(ctx)) + assert.NoError(t, node2.Start(ctx)) + assert.NoError(t, node3.Start(ctx)) // Create three heartbeat listeners, associate one with each node testUUID := cbgt.NewUUID() @@ -352,21 +355,21 @@ func TestCBGTManagerHeartbeater(t *testing.T) { "some-datasource", eventHandlers, options) - listener1, err := NewImportHeartbeatListener(&CbgtContext{ + listener1, err := NewImportHeartbeatListener(ctx, &CbgtContext{ Cfg: cfgCB, Manager: testManager, }) assert.NoError(t, err) assert.NoError(t, node1.RegisterListener(listener1)) - listener2, err := NewImportHeartbeatListener(&CbgtContext{ + listener2, err := NewImportHeartbeatListener(ctx, &CbgtContext{ Cfg: cfgCB, Manager: testManager, }) assert.NoError(t, err) assert.NoError(t, node2.RegisterListener(listener2)) - listener3, err := NewImportHeartbeatListener(&CbgtContext{ + listener3, err := NewImportHeartbeatListener(ctx, &CbgtContext{ Cfg: cfgCB, Manager: testManager, }) @@ -382,7 +385,7 @@ func TestCBGTManagerHeartbeater(t *testing.T) { assert.True(t, node1.sendCount > 0) // Stop node 1 - node1.Stop() + node1.Stop(ctx) // Wait for another node to detect node1 has stopped sending heartbeats retryUntilFunc = func() bool { @@ -404,6 +407,6 @@ func TestCBGTManagerHeartbeater(t *testing.T) { assert.Contains(t, activeNodes, "node3") // Stop heartbeaters - node2.Stop() - node3.Stop() + node2.Stop(ctx) + node3.Stop(ctx) } diff --git a/base/http_listener.go b/base/http_listener.go index ceefbcecb0..7fe7df0139 100644 --- a/base/http_listener.go +++ b/base/http_listener.go @@ -34,7 +34,7 @@ const ( // This is like a combination of http.ListenAndServe and http.ListenAndServeTLS, which also // uses ThrottledListen to limit the number of open HTTP connections. -func ListenAndServeHTTP(addr string, connLimit uint, certFile, keyFile string, handler http.Handler, +func ListenAndServeHTTP(ctx context.Context, addr string, connLimit uint, certFile, keyFile string, handler http.Handler, readTimeout, writeTimeout, readHeaderTimeout, idleTimeout time.Duration, http2Enabled bool, tlsMinVersion uint16) (serveFn func() error, server *http.Server, err error) { var config *tls.Config @@ -46,7 +46,7 @@ func ListenAndServeHTTP(addr string, connLimit uint, certFile, keyFile string, h protocolsEnabled = []string{"h2", "http/1.1"} } config.NextProtos = protocolsEnabled - InfofCtx(context.TODO(), KeyHTTP, "Protocols enabled: %v on %v", config.NextProtos, SD(addr)) + InfofCtx(ctx, KeyHTTP, "Protocols enabled: %v on %v", config.NextProtos, SD(addr)) config.Certificates = make([]tls.Certificate, 1) var err error config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) @@ -58,7 +58,7 @@ func ListenAndServeHTTP(addr string, connLimit uint, certFile, keyFile string, h // Callback that turns off TCP NODELAY option when a client transitions to a WebSocket: connStateFunc := func(clientConn net.Conn, state http.ConnState) { if state == http.StateHijacked { - turnOffNoDelay(context.Background(), clientConn) + turnOffNoDelay(ctx, clientConn) } } diff --git a/base/log_keys.go b/base/log_keys.go index a69d1166c8..b89947b870 100644 --- a/base/log_keys.go +++ b/base/log_keys.go @@ -191,7 +191,7 @@ func (keyMask *LogKeyMask) EnabledLogKeys() []string { // ToLogKey takes a slice of case-sensitive log key names and will return a LogKeyMask bitfield // and a slice of deferred log functions for any warnings that may occurr. -func ToLogKey(keysStr []string) (logKeys LogKeyMask) { +func ToLogKey(ctx context.Context, keysStr []string) (logKeys LogKeyMask) { for _, key := range keysStr { // Take a copy of key, so we can use it in a closure outside the scope @@ -209,14 +209,14 @@ func ToLogKey(keysStr []string) (logKeys LogKeyMask) { // Strip a single "+" suffix in log keys and warn (for backwards compatibility) if strings.HasSuffix(key, "+") { newLogKey := strings.TrimSuffix(key, "+") - WarnfCtx(context.Background(), "Deprecated log key: %q found. Changing to: %q.", originalKey, newLogKey) + WarnfCtx(ctx, "Deprecated log key: %q found. Changing to: %q.", originalKey, newLogKey) key = newLogKey } if logKey, ok := logKeyNamesInverse[key]; ok { logKeys.Enable(logKey) } else { - WarnfCtx(context.Background(), "Invalid log key: %v", originalKey) + WarnfCtx(ctx, "Invalid log key: %v", originalKey) } } diff --git a/base/log_keys_test.go b/base/log_keys_test.go index 8650c9dd4b..26c45dc243 100644 --- a/base/log_keys_test.go +++ b/base/log_keys_test.go @@ -58,34 +58,35 @@ func TestLogKeyNames(t *testing.T) { assert.Contains(t, name, "Replicate") keys := []string{} - logKeys := ToLogKey(keys) + ctx := TestCtx(t) + logKeys := ToLogKey(ctx, keys) assert.Equal(t, LogKeyMask(0), logKeys) assert.Equal(t, []string{}, logKeys.EnabledLogKeys()) keys = append(keys, "DCP") - logKeys = ToLogKey(keys) + logKeys = ToLogKey(ctx, keys) assert.Equal(t, *logKeyMask(KeyDCP), logKeys) assert.Equal(t, []string{KeyDCP.String()}, logKeys.EnabledLogKeys()) keys = append(keys, "Access") - logKeys = ToLogKey(keys) + logKeys = ToLogKey(ctx, keys) assert.Equal(t, *logKeyMask(KeyAccess, KeyDCP), logKeys) assert.Equal(t, []string{KeyAccess.String(), KeyDCP.String()}, logKeys.EnabledLogKeys()) keys = []string{"*", "DCP"} - logKeys = ToLogKey(keys) + logKeys = ToLogKey(ctx, keys) assert.Equal(t, *logKeyMask(KeyAll, KeyDCP), logKeys) assert.Equal(t, []string{KeyAll.String(), KeyDCP.String()}, logKeys.EnabledLogKeys()) // Special handling of log keys keys = []string{"HTTP+"} - logKeys = ToLogKey(keys) + logKeys = ToLogKey(ctx, keys) assert.Equal(t, *logKeyMask(KeyHTTP, KeyHTTPResp), logKeys) assert.Equal(t, []string{KeyHTTP.String(), KeyHTTPResp.String()}, logKeys.EnabledLogKeys()) // Test that invalid log keys are ignored, and "+" suffixes are stripped. keys = []string{"DCP", "WS+", "InvalidLogKey"} - logKeys = ToLogKey(keys) + logKeys = ToLogKey(ctx, keys) assert.Equal(t, *logKeyMask(KeyDCP, KeyWebSocket), logKeys) assert.Equal(t, []string{KeyDCP.String(), KeyWebSocket.String()}, logKeys.EnabledLogKeys()) } @@ -190,8 +191,9 @@ func BenchmarkLogKeyName(b *testing.B) { } func BenchmarkToLogKey(b *testing.B) { + ctx := TestCtx(b) for i := 0; i < b.N; i++ { - _ = ToLogKey([]string{"CRUD", "DCP", "Replicate"}) + _ = ToLogKey(ctx, []string{"CRUD", "DCP", "Replicate"}) } } diff --git a/base/logger_console.go b/base/logger_console.go index c965b751d4..c57674d6fc 100644 --- a/base/logger_console.go +++ b/base/logger_console.go @@ -60,7 +60,7 @@ func NewConsoleLogger(ctx context.Context, shouldLogLocation bool, config *Conso return nil, err } - logKey := ToLogKey(config.LogKeys) + logKey := ToLogKey(ctx, config.LogKeys) logger := &ConsoleLogger{ LogLevel: config.LogLevel, diff --git a/base/logger_file.go b/base/logger_file.go index 98535b81a7..b93a470587 100644 --- a/base/logger_file.go +++ b/base/logger_file.go @@ -76,14 +76,14 @@ type logRotationConfig struct { } // NewFileLogger returns a new FileLogger from a config. -func NewFileLogger(config *FileLoggerConfig, level LogLevel, name string, logFilePath string, minAge int, buffer *strings.Builder) (*FileLogger, error) { +func NewFileLogger(ctx context.Context, config *FileLoggerConfig, level LogLevel, name string, logFilePath string, minAge int, buffer *strings.Builder) (*FileLogger, error) { if config == nil { config = &FileLoggerConfig{} } // validate and set defaults - if err := config.init(level, name, logFilePath, minAge); err != nil { + if err := config.init(ctx, level, name, logFilePath, minAge); err != nil { return nil, err } @@ -174,7 +174,7 @@ func (l *FileLogger) getFileLoggerConfig() *FileLoggerConfig { return &fileLoggerConfig } -func (lfc *FileLoggerConfig) init(level LogLevel, name string, logFilePath string, minAge int) error { +func (lfc *FileLoggerConfig) init(ctx context.Context, level LogLevel, name string, logFilePath string, minAge int) error { if lfc == nil { return errors.New("nil LogFileConfig") } @@ -209,15 +209,15 @@ func (lfc *FileLoggerConfig) init(level LogLevel, name string, logFilePath strin go func() { defer func() { if panicked := recover(); panicked != nil { - WarnfCtx(context.Background(), "Panic when deleting rotated log files: \n %s", panicked, debug.Stack()) + WarnfCtx(ctx, "Panic when deleting rotated log files: \n %s", panicked, debug.Stack()) } }() for { select { case <-ticker.C: - err := runLogDeletion(logFilePath, level.String(), int(float64(*lfc.Rotation.RotatedLogsSizeLimit)*rotatedLogsLowWatermarkMultiplier), *lfc.Rotation.RotatedLogsSizeLimit) + err := runLogDeletion(ctx, logFilePath, level.String(), int(float64(*lfc.Rotation.RotatedLogsSizeLimit)*rotatedLogsLowWatermarkMultiplier), *lfc.Rotation.RotatedLogsSizeLimit) if err != nil { - WarnfCtx(context.Background(), "%s", err) + WarnfCtx(ctx, "%s", err) } } } @@ -267,7 +267,7 @@ func newLumberjackOutput(filename string, maxSize, maxAge int) *lumberjack.Logge // runLogDeletion will delete rotated logs for the supplied logLevel. It will only perform these deletions when the // cumulative size of the logs are above the supplied sizeLimitMB. // logDirectory is the supplied directory where the logs are stored. -func runLogDeletion(logDirectory string, logLevel string, sizeLimitMBLowWatermark int, sizeLimitMBHighWatermark int) (err error) { +func runLogDeletion(ctx context.Context, logDirectory string, logLevel string, sizeLimitMBLowWatermark int, sizeLimitMBHighWatermark int) (err error) { sizeLimitMBLowWatermark = sizeLimitMBLowWatermark * 1024 * 1024 // Convert MB input to bytes sizeLimitMBHighWatermark = sizeLimitMBHighWatermark * 1024 * 1024 // Convert MB input to bytes @@ -288,7 +288,7 @@ func runLogDeletion(logDirectory string, logLevel string, sizeLimitMBLowWatermar if strings.HasPrefix(file.Name(), logFilePrefix+logLevel) && strings.HasSuffix(file.Name(), ".log.gz") { fi, err := file.Info() if err != nil { - InfofCtx(context.TODO(), KeyAll, "Couldn't get size of log file %q: %v - ignoring for cleanup calculation", file.Name(), err) + InfofCtx(ctx, KeyAll, "Couldn't get size of log file %q: %v - ignoring for cleanup calculation", file.Name(), err) continue } diff --git a/base/logger_file_test.go b/base/logger_file_test.go index 290a9b4a02..168dcb3835 100644 --- a/base/logger_file_test.go +++ b/base/logger_file_test.go @@ -128,9 +128,10 @@ func TestRotatedLogDeletion(t *testing.T) { assert.NoError(t, err) err = makeTestFile(2, logFilePrefix+"info-2019-02-02T12-10-00.log.gz", dir) assert.NoError(t, err) - err = runLogDeletion(dir, "error", 3, 5) + ctx := TestCtx(t) + err = runLogDeletion(ctx, dir, "error", 3, 5) assert.NoError(t, err) - err = runLogDeletion(dir, "info", 5, 7) + err = runLogDeletion(ctx, dir, "info", 5, 7) assert.NoError(t, err) dirContents, err = os.ReadDir(dir) require.NoError(t, err) @@ -152,7 +153,7 @@ func TestRotatedLogDeletion(t *testing.T) { dir = t.TempDir() err = makeTestFile(3, logFilePrefix+"error.log.gz", dir) assert.NoError(t, err) - err = runLogDeletion(dir, "error", 2, 4) + err = runLogDeletion(ctx, dir, "error", 2, 4) assert.NoError(t, err) dirContents, err = os.ReadDir(dir) require.NoError(t, err) @@ -163,7 +164,7 @@ func TestRotatedLogDeletion(t *testing.T) { dir = t.TempDir() err = makeTestFile(5, logFilePrefix+"error.log.gz", dir) assert.NoError(t, err) - err = runLogDeletion(dir, "error", 2, 4) + err = runLogDeletion(ctx, dir, "error", 2, 4) assert.NoError(t, err) dirContents, err = os.ReadDir(dir) require.NoError(t, err) @@ -174,7 +175,7 @@ func TestRotatedLogDeletion(t *testing.T) { dir = t.TempDir() err = makeTestFile(1, logFilePrefix+"error.log.gz", dir) assert.NoError(t, err) - err = runLogDeletion(dir, "error", 2, 4) + err = runLogDeletion(ctx, dir, "error", 2, 4) assert.NoError(t, err) dirContents, err = os.ReadDir(dir) require.NoError(t, err) @@ -191,7 +192,7 @@ func TestRotatedLogDeletion(t *testing.T) { assert.NoError(t, err) err = makeTestFile(1, logFilePrefix+"error-2019-01-01T12-00-00.log.gz", dir) assert.NoError(t, err) - err = runLogDeletion(dir, "error", 2, 3) + err = runLogDeletion(ctx, dir, "error", 2, 3) assert.NoError(t, err) dirContents, err = os.ReadDir(dir) diff --git a/base/logging.go b/base/logging.go index 919c3f76b1..00366253bc 100644 --- a/base/logging.go +++ b/base/logging.go @@ -34,7 +34,7 @@ func GetLogKeys() map[string]bool { } // UpdateLogKeys updates the console's log keys from a map -func UpdateLogKeys(keys map[string]bool, replace bool) { +func UpdateLogKeys(ctx context.Context, keys map[string]bool, replace bool) { if replace { ConsoleLogKey().Set(logKeyMask(KeyNone)) } @@ -48,7 +48,7 @@ func UpdateLogKeys(keys map[string]bool, replace bool) { } } - InfofCtx(context.Background(), KeyAll, "Setting log keys to: %v", ConsoleLogKey().EnabledLogKeys()) + InfofCtx(ctx, KeyAll, "Setting log keys to: %v", ConsoleLogKey().EnabledLogKeys()) } // Returns a string identifying a function on the call stack. @@ -94,8 +94,8 @@ var ( ) // RotateLogfiles rotates all active log files. -func RotateLogfiles() map[*FileLogger]error { - InfofCtx(context.Background(), KeyAll, "Rotating log files...") +func RotateLogfiles(ctx context.Context) map[*FileLogger]error { + InfofCtx(ctx, KeyAll, "Rotating log files...") loggers := map[*FileLogger]error{ traceLogger: nil, @@ -272,7 +272,7 @@ func LogSyncGatewayVersion(ctx context.Context) { ConsolefCtx(ctx, LevelNone, KeyNone, msg) // Log the startup indicator to ALL log files too. - msg = addPrefixes(msg, context.Background(), LevelNone, KeyNone) + msg = addPrefixes(msg, ctx, LevelNone, KeyNone) if errorLogger.shouldLog(LevelNone) { errorLogger.logger.Printf(msg) } @@ -399,7 +399,7 @@ func AssertLogContains(t *testing.T, s string, f func()) { } return true, nil, nil } - err, _ := RetryLoop("wait for logs", retry, CreateSleeperFunc(10, 100)) + err, _ := RetryLoop(TestCtx(t), "wait for logs", retry, CreateSleeperFunc(10, 100)) consoleLogger.logger.SetOutput(os.Stderr) assert.NoError(t, err, "Console logs did not contain %q", s) diff --git a/base/logging_config.go b/base/logging_config.go index 46a6ae968a..4d84a24ded 100644 --- a/base/logging_config.go +++ b/base/logging_config.go @@ -89,33 +89,33 @@ func InitLogging(ctx context.Context, logFilePath string, ConsolefCtx(ctx, LevelInfo, KeyNone, "Logging: Files to %v", logFilePath) } - errorLogger, err = NewFileLogger(error, LevelError, LevelError.String(), logFilePath, errorMinAge, &errorLogger.buffer) + errorLogger, err = NewFileLogger(ctx, error, LevelError, LevelError.String(), logFilePath, errorMinAge, &errorLogger.buffer) if err != nil { return err } - warnLogger, err = NewFileLogger(warn, LevelWarn, LevelWarn.String(), logFilePath, warnMinAge, &warnLogger.buffer) + warnLogger, err = NewFileLogger(ctx, warn, LevelWarn, LevelWarn.String(), logFilePath, warnMinAge, &warnLogger.buffer) if err != nil { return err } - infoLogger, err = NewFileLogger(info, LevelInfo, LevelInfo.String(), logFilePath, infoMinAge, &infoLogger.buffer) + infoLogger, err = NewFileLogger(ctx, info, LevelInfo, LevelInfo.String(), logFilePath, infoMinAge, &infoLogger.buffer) if err != nil { return err } - debugLogger, err = NewFileLogger(debug, LevelDebug, LevelDebug.String(), logFilePath, debugMinAge, &debugLogger.buffer) + debugLogger, err = NewFileLogger(ctx, debug, LevelDebug, LevelDebug.String(), logFilePath, debugMinAge, &debugLogger.buffer) if err != nil { return err } - traceLogger, err = NewFileLogger(trace, LevelTrace, LevelTrace.String(), logFilePath, traceMinAge, &traceLogger.buffer) + traceLogger, err = NewFileLogger(ctx, trace, LevelTrace, LevelTrace.String(), logFilePath, traceMinAge, &traceLogger.buffer) if err != nil { return err } // Since there is no level checking in the stats logging, use LevelNone for the level. - statsLogger, err = NewFileLogger(stats, LevelNone, "stats", logFilePath, statsMinage, &statsLogger.buffer) + statsLogger, err = NewFileLogger(ctx, stats, LevelNone, "stats", logFilePath, statsMinage, &statsLogger.buffer) if err != nil { return err } diff --git a/base/logging_context_test.go b/base/logging_context_test.go index 40386c9533..33e3cf7b6a 100644 --- a/base/logging_context_test.go +++ b/base/logging_context_test.go @@ -51,7 +51,7 @@ func requireLogIs(t testing.TB, s string, f func()) { } return true, nil, nil } - err, _ := RetryLoop("wait for logs", retry, CreateSleeperFunc(10, 100)) + err, _ := RetryLoop(TestCtx(t), "wait for logs", retry, CreateSleeperFunc(10, 100)) require.NoError(t, err, "Console logs did not contain %q, got %q", s, originalLog) } diff --git a/base/logging_test.go b/base/logging_test.go index 1e91ec410e..0af3645624 100644 --- a/base/logging_test.go +++ b/base/logging_test.go @@ -98,7 +98,8 @@ func BenchmarkLogRotation(b *testing.B) { // Tidy up temp log files in a retry loop because // we can't remove temp dir while the async compression is still writing log files assert.NoError(bm, logger.Close()) - err, _ = RetryLoop("benchmark-logrotate-teardown", + ctx := TestCtx(bm) + err, _ = RetryLoop(ctx, "benchmark-logrotate-teardown", func() (shouldRetry bool, err error, value interface{}) { err = os.RemoveAll(logPath) return err != nil, err, nil diff --git a/base/main_test.go b/base/main_test.go index aee02f865f..1c806bd17d 100644 --- a/base/main_test.go +++ b/base/main_test.go @@ -11,10 +11,12 @@ licenses/APL2.txt. package base import ( + "context" "testing" ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - TestBucketPoolNoIndexes(m, tbpOptions) + TestBucketPoolNoIndexes(ctx, m, tbpOptions) } diff --git a/base/main_test_bucket_pool.go b/base/main_test_bucket_pool.go index ddea93002c..cf4ee18425 100644 --- a/base/main_test_bucket_pool.go +++ b/base/main_test_bucket_pool.go @@ -71,12 +71,12 @@ type TestBucketPoolOptions struct { UseDefaultScope bool } -func NewTestBucketPool(bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc) *TestBucketPool { - return NewTestBucketPoolWithOptions(bucketReadierFunc, bucketInitFunc, TestBucketPoolOptions{}) +func NewTestBucketPool(ctx context.Context, bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc) *TestBucketPool { + return NewTestBucketPoolWithOptions(ctx, bucketReadierFunc, bucketInitFunc, TestBucketPoolOptions{}) } // NewTestBucketPool initializes a new TestBucketPool. To be called from TestMain for packages requiring test buckets. -func NewTestBucketPoolWithOptions(bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc, options TestBucketPoolOptions) *TestBucketPool { +func NewTestBucketPoolWithOptions(ctx context.Context, bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc, options TestBucketPoolOptions) *TestBucketPool { // We can safely skip setup when we want Walrus buckets to be used. // They'll be created on-demand via GetTestBucketAndSpec, // which is fast enough for Walrus that we don't need to prepare buckets ahead of time. @@ -92,15 +92,15 @@ func NewTestBucketPoolWithOptions(bucketReadierFunc TBPBucketReadierFunc, bucket return &tbp } - _, err := SetMaxFileDescriptors(5000) + // Used to manage cancellation of worker goroutines + ctx, ctxCancelFunc := context.WithCancel(ctx) + + _, err := SetMaxFileDescriptors(ctx, 5000) if err != nil { - FatalfCtx(context.TODO(), "couldn't set max file descriptors: %v", err) + FatalfCtx(ctx, "couldn't set max file descriptors: %v", err) } - numBuckets := tbpNumBuckets() - - // Used to manage cancellation of worker goroutines - ctx, ctxCancelFunc := context.WithCancel(context.Background()) + numBuckets := tbpNumBuckets(ctx) preserveBuckets, _ := strconv.ParseBool(os.Getenv(tbpEnvPreserve)) @@ -117,9 +117,9 @@ func NewTestBucketPoolWithOptions(bucketReadierFunc TBPBucketReadierFunc, bucket useDefaultScope: options.UseDefaultScope, } - tbp.cluster = newTestCluster(UnitTestUrl(), tbp.Logf) + tbp.cluster = newTestCluster(ctx, UnitTestUrl(), tbp.Logf) - useCollections, err := tbp.canUseNamedCollections() + useCollections, err := tbp.canUseNamedCollections(ctx) if err != nil { tbp.Fatalf(ctx, "%s", err) } @@ -130,14 +130,14 @@ func NewTestBucketPoolWithOptions(bucketReadierFunc TBPBucketReadierFunc, bucket // Start up an async readier worker to process dirty buckets go tbp.bucketReadierWorker(ctx, bucketReadierFunc) - err = tbp.removeOldTestBuckets() + err = tbp.removeOldTestBuckets(ctx) if err != nil { tbp.Fatalf(ctx, "Couldn't remove old test buckets: %v", err) } // Make sure the test buckets are created and put into the readier worker queue start := time.Now() - if err := tbp.createTestBuckets(numBuckets, tbpBucketQuotaMB(), bucketInitFunc); err != nil { + if err := tbp.createTestBuckets(numBuckets, tbpBucketQuotaMB(ctx), bucketInitFunc); err != nil { tbp.Fatalf(ctx, "Couldn't create test buckets: %v", err) } atomic.AddInt64(&tbp.stats.TotalBucketInitDurationNano, time.Since(start).Nanoseconds()) @@ -178,7 +178,7 @@ func (tbp *TestBucketPool) markBucketClosed(t testing.TB, b Bucket) { } func (tbp *TestBucketPool) checkForViewOpsQueueEmptied(ctx context.Context, bucketName string, c chan struct{}) { - if err, _ := RetryLoop(bucketName+"-emptyViewOps", func() (bool, error, interface{}) { + if err, _ := RetryLoop(ctx, bucketName+"-emptyViewOps", func() (bool, error, interface{}) { if len(c) > 0 { return true, fmt.Errorf("view op queue not cleared. remaining: %d", len(c)), nil } @@ -273,21 +273,21 @@ func (tbp *TestBucketPool) GetWalrusTestBucket(t testing.TB, url string) (b Buck // GetExistingBucket opens a bucket conection to an existing bucket func (tbp *TestBucketPool) GetExistingBucket(t testing.TB) (b Bucket, s BucketSpec, teardown func()) { - testCtx := TestCtx(t) + ctx := TestCtx(t) - bucketCluster := initV2Cluster(UnitTestUrl()) + bucketCluster := initV2Cluster(ctx, UnitTestUrl()) bucketName := tbpBucketName(TestUseExistingBucketName()) bucketSpec := getTestBucketSpec(bucketName) - bucketFromSpec, err := GetGocbV2BucketFromCluster(bucketCluster, bucketSpec, waitForReadyBucketTimeout, false) + bucketFromSpec, err := GetGocbV2BucketFromCluster(ctx, bucketCluster, bucketSpec, waitForReadyBucketTimeout, false) if err != nil { - tbp.Fatalf(testCtx, "couldn't get existing collection from cluster: %v", err) + tbp.Fatalf(ctx, "couldn't get existing collection from cluster: %v", err) } - DebugfCtx(context.TODO(), KeySGTest, "opened bucket %s", bucketName) + DebugfCtx(ctx, KeySGTest, "opened bucket %s", bucketName) return bucketFromSpec, bucketSpec, func() { - tbp.Logf(testCtx, "Teardown called - Closing connection to existing bucket") + tbp.Logf(ctx, "Teardown called - Closing connection to existing bucket") bucketFromSpec.Close() } } @@ -373,7 +373,7 @@ func (tbp *TestBucketPool) addBucketToReadierQueue(ctx context.Context, name tbp } // Close waits for any buckets to be cleaned, and closes the pool. -func (tbp *TestBucketPool) Close() { +func (tbp *TestBucketPool) Close(ctx context.Context) { if tbp == nil { // noop return @@ -386,8 +386,8 @@ func (tbp *TestBucketPool) Close() { } if tbp.cluster != nil { - if err := tbp.cluster.close(); err != nil { - tbp.Logf(context.Background(), "Couldn't close cluster connection: %v", err) + if err := tbp.cluster.close(ctx); err != nil { + tbp.Logf(ctx, "Couldn't close cluster connection: %v", err) } } @@ -395,7 +395,7 @@ func (tbp *TestBucketPool) Close() { } // removeOldTestBuckets removes all buckets starting with testBucketNamePrefix -func (tbp *TestBucketPool) removeOldTestBuckets() error { +func (tbp *TestBucketPool) removeOldTestBuckets(ctx context.Context) error { buckets, err := tbp.cluster.getBucketNames() if err != nil { return errors.Wrap(err, "couldn't retrieve buckets from cluster manager") @@ -405,7 +405,7 @@ func (tbp *TestBucketPool) removeOldTestBuckets() error { for _, b := range buckets { if strings.HasPrefix(b, tbpBucketNamePrefix) { - ctx := bucketNameCtx(context.Background(), b) + ctx := bucketNameCtx(ctx, b) tbp.Logf(ctx, "Removing old test bucket") wg.Add(1) @@ -457,10 +457,10 @@ func (tbp *TestBucketPool) createCollections(ctx context.Context, bucket Bucket) dynamicDataStore, ok := bucket.(sgbucket.DynamicDataStoreBucket) if !ok { - tbp.Fatalf(ctx, "Bucket doesn't support dynamic collection creation") + tbp.Fatalf(ctx, "Bucket doesn't support dynamic collection creation %T", bucket) } - for i := 0; i < tbpNumCollectionsPerBucket(); i++ { + for i := 0; i < tbpNumCollectionsPerBucket(ctx); i++ { scopeName := tbp.testScopeName() collectionName := fmt.Sprintf("%s%d", tbpCollectionPrefix, i) ctx := KeyspaceLogCtx(ctx, bucket.GetName(), scopeName, collectionName) @@ -501,12 +501,12 @@ func (tbp *TestBucketPool) createTestBuckets(numBuckets, bucketQuotaMB int, buck ctx := bucketNameCtx(ctx, bucketName) tbp.Logf(ctx, "Creating new test bucket") - err := tbp.cluster.insertBucket(bucketName, bucketQuotaMB) + err := tbp.cluster.insertBucket(ctx, bucketName, bucketQuotaMB) if err != nil { tbp.Fatalf(ctx, "Couldn't create test bucket: %v", err) } - bucket, err := tbp.cluster.openTestBucket(tbpBucketName(bucketName), waitForReadyBucketTimeout) + bucket, err := tbp.cluster.openTestBucket(ctx, tbpBucketName(bucketName), waitForReadyBucketTimeout) if err != nil { tbp.Fatalf(ctx, "Timed out trying to open new bucket: %v", err) } @@ -534,7 +534,7 @@ func (tbp *TestBucketPool) createTestBuckets(numBuckets, bucketQuotaMB int, buck b := openBuckets[testBucketName] itemName := "bucket" - if err, _ := RetryLoop(b.GetName()+"bucketInitRetry", func() (bool, error, interface{}) { + if err, _ := RetryLoop(ctx, b.GetName()+"bucketInitRetry", func() (bool, error, interface{}) { tbp.Logf(ctx, "Running %s through init function", itemName) ctx = KeyspaceLogCtx(ctx, b.GetName(), "", "") err := bucketInitFunc(ctx, b, tbp) @@ -578,14 +578,14 @@ loop: defer tbp.bucketReadierWaitGroup.Done() start := time.Now() - b, err := tbp.cluster.openTestBucket(testBucketName, waitForReadyBucketTimeout) + b, err := tbp.cluster.openTestBucket(ctx, testBucketName, waitForReadyBucketTimeout) ctx = KeyspaceLogCtx(ctx, b.GetName(), "", "") if err != nil { tbp.Logf(ctx, "Couldn't open bucket to get ready, got error: %v", err) return } - err, _ = RetryLoop(b.GetName()+"bucketReadierRetry", func() (bool, error, interface{}) { + err, _ = RetryLoop(ctx, b.GetName()+"bucketReadierRetry", func() (bool, error, interface{}) { tbp.Logf(ctx, "Running bucket through readier function") err = bucketReadierFunc(ctx, b, tbp) if err != nil { @@ -685,18 +685,18 @@ var N1QLBucketEmptierFunc TBPBucketReadierFunc = func(ctx context.Context, b Buc type tbpBucketName string // TestBucketPoolMain is used as TestMain in main_test.go packages -func TestBucketPoolMain(m *testing.M, bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc, +func TestBucketPoolMain(ctx context.Context, m *testing.M, bucketReadierFunc TBPBucketReadierFunc, bucketInitFunc TBPBucketInitFunc, options TestBucketPoolOptions) { // can't use defer because of os.Exit teardownFuncs := make([]func(), 0) - teardownFuncs = append(teardownFuncs, SetUpGlobalTestLogging(m)) + teardownFuncs = append(teardownFuncs, SetUpGlobalTestLogging(ctx, m)) teardownFuncs = append(teardownFuncs, SetUpGlobalTestProfiling(m)) teardownFuncs = append(teardownFuncs, SetUpGlobalTestMemoryWatermark(m, options.MemWatermarkThresholdMB)) SkipPrometheusStatsRegistration = true - GTestBucketPool = NewTestBucketPoolWithOptions(bucketReadierFunc, bucketInitFunc, options) - teardownFuncs = append(teardownFuncs, GTestBucketPool.Close) + GTestBucketPool = NewTestBucketPoolWithOptions(ctx, bucketReadierFunc, bucketInitFunc, options) + teardownFuncs = append(teardownFuncs, func() { GTestBucketPool.Close(ctx) }) // must be the last teardown function added to the list to correctly detect leaked goroutines teardownFuncs = append(teardownFuncs, SetUpTestGoroutineDump(m)) @@ -712,6 +712,6 @@ func TestBucketPoolMain(m *testing.M, bucketReadierFunc TBPBucketReadierFunc, bu } // TestBucketPoolNoIndexes runs a TestMain for packages that do not require creation of indexes -func TestBucketPoolNoIndexes(m *testing.M, options TestBucketPoolOptions) { - TestBucketPoolMain(m, FlushBucketEmptierFunc, NoopInitFunc, options) +func TestBucketPoolNoIndexes(ctx context.Context, m *testing.M, options TestBucketPoolOptions) { + TestBucketPoolMain(ctx, m, FlushBucketEmptierFunc, NoopInitFunc, options) } diff --git a/base/main_test_bucket_pool_config.go b/base/main_test_bucket_pool_config.go index f37071d9bb..471d0f6fb5 100644 --- a/base/main_test_bucket_pool_config.go +++ b/base/main_test_bucket_pool_config.go @@ -71,7 +71,8 @@ var tbpDefaultBucketSpec = BucketSpec{ // TestsUseNamedCollections returns true if the tests use named collections. func TestsUseNamedCollections() bool { - ok, err := GTestBucketPool.canUseNamedCollections() + ctx := context.Background() + ok, err := GTestBucketPool.canUseNamedCollections(ctx) return err == nil && ok } @@ -90,7 +91,7 @@ func TestsRequireMobileRBAC(t *testing.T) { } // canUseNamedCollections returns true if the cluster supports named collections, and they are also requested -func (tbp *TestBucketPool) canUseNamedCollections() (bool, error) { +func (tbp *TestBucketPool) canUseNamedCollections(ctx context.Context) (bool, error) { // walrus supports collections, but we need to query the server's version for capability check clusterSupport := true if tbp.cluster != nil { @@ -111,10 +112,10 @@ func (tbp *TestBucketPool) canUseNamedCollections() (bool, error) { useDefaultCollection, isSet := os.LookupEnv(tbpEnvUseDefaultCollection) if !isSet { if !queryStoreSupportsCollections { - tbp.Logf(context.TODO(), "GSI disabled - not using named collections") + tbp.Logf(ctx, "GSI disabled - not using named collections") return false, nil } - tbp.Logf(context.TODO(), "Will use named collections if cluster supports them: %v", clusterSupport) + tbp.Logf(ctx, "Will use named collections if cluster supports them: %v", clusterSupport) // use collections if running GSI and server >= 7 return clusterSupport, nil } @@ -136,52 +137,52 @@ func (tbp *TestBucketPool) canUseNamedCollections() (bool, error) { } // tbpNumBuckets returns the configured number of buckets to use in the pool. -func tbpNumBuckets() int { +func tbpNumBuckets(ctx context.Context) int { numBuckets := tbpDefaultBucketPoolSize if envPoolSize := os.Getenv(tbpEnvBucketPoolSize); envPoolSize != "" { var err error numBuckets, err = strconv.Atoi(envPoolSize) if err != nil { - FatalfCtx(context.TODO(), "Couldn't parse %s: %v", tbpEnvBucketPoolSize, err) + FatalfCtx(ctx, "Couldn't parse %s: %v", tbpEnvBucketPoolSize, err) } } return numBuckets } // tbpNumReplicasreturns the number of replicas to use in each bucket. -func tbpNumReplicas() uint32 { +func tbpNumReplicas(ctx context.Context) uint32 { numReplicas := os.Getenv(tbpEnvBucketNumReplicas) if numReplicas == "" { return 0 } replicas, err := strconv.Atoi(numReplicas) if err != nil { - FatalfCtx(context.TODO(), "Couldn't parse %s: %v", tbpEnvBucketPoolSize, err) + FatalfCtx(ctx, "Couldn't parse %s: %v", tbpEnvBucketPoolSize, err) } return uint32(replicas) } // tbpNumCollectionsPerBucket returns the configured number of collections prepared in a bucket. -func tbpNumCollectionsPerBucket() int { +func tbpNumCollectionsPerBucket(ctx context.Context) int { numCollectionsPerBucket := tbpDefaultCollectionPoolSize if envCollectionPoolSize := os.Getenv(tbpEnvCollectionPoolSize); envCollectionPoolSize != "" { var err error numCollectionsPerBucket, err = strconv.Atoi(envCollectionPoolSize) if err != nil { - FatalfCtx(context.TODO(), "Couldn't parse %s: %v", tbpEnvCollectionPoolSize, err) + FatalfCtx(ctx, "Couldn't parse %s: %v", tbpEnvCollectionPoolSize, err) } } return numCollectionsPerBucket } // tbpBucketQuotaMB returns the configured bucket RAM quota. -func tbpBucketQuotaMB() int { +func tbpBucketQuotaMB(ctx context.Context) int { bucketQuota := defaultBucketQuotaMB if envBucketQuotaMB := os.Getenv(tbpEnvBucketQuotaMB); envBucketQuotaMB != "" { var err error bucketQuota, err = strconv.Atoi(envBucketQuotaMB) if err != nil { - FatalfCtx(context.TODO(), "Couldn't parse %s: %v", tbpEnvBucketQuotaMB, err) + FatalfCtx(ctx, "Couldn't parse %s: %v", tbpEnvBucketQuotaMB, err) } } return bucketQuota diff --git a/base/main_test_bucket_pool_util.go b/base/main_test_bucket_pool_util.go index 0ea1e1c711..44b172709e 100644 --- a/base/main_test_bucket_pool_util.go +++ b/base/main_test_bucket_pool_util.go @@ -55,7 +55,7 @@ func RequireNumTestBuckets(t *testing.T, numRequired int) { // RequireNumTestDataStores skips the given test if there are not enough test buckets available to use. func RequireNumTestDataStores(t testing.TB, numRequired int) { TestRequiresCollections(t) - available := tbpNumCollectionsPerBucket() + available := tbpNumCollectionsPerBucket(TestCtx(t)) if available < numRequired { t.Skipf("Only had %d usable test data stores available (test requires %d)", available, numRequired) } @@ -68,5 +68,5 @@ func (tbp *TestBucketPool) numUsableBuckets() int { // so report back 10 to match a fully available CBS bucket pool. return 10 } - return tbpNumBuckets() - int(atomic.LoadUint32(&tbp.preservedBucketCount)) + return tbpNumBuckets(context.Background()) - int(atomic.LoadUint32(&tbp.preservedBucketCount)) } diff --git a/base/main_test_cluster.go b/base/main_test_cluster.go index ef13cf18a4..fe1fc12d4f 100644 --- a/base/main_test_cluster.go +++ b/base/main_test_cluster.go @@ -20,21 +20,21 @@ import ( // tbpCluster defines the required test bucket pool cluster operations type tbpCluster interface { getBucketNames() ([]string, error) - insertBucket(name string, quotaMB int) error + insertBucket(ctx context.Context, name string, quotaMB int) error removeBucket(name string) error - openTestBucket(name tbpBucketName, waitUntilReady time.Duration) (Bucket, error) + openTestBucket(ctx context.Context, name tbpBucketName, waitUntilReady time.Duration) (Bucket, error) supportsCollections() (bool, error) supportsMobileRBAC() (bool, error) isServerEnterprise() (bool, error) - close() error + close(context.Context) error } type clusterLogFunc func(ctx context.Context, format string, args ...interface{}) // newTestCluster returns a cluster based on the driver used by the defaultBucketSpec. Accepts a clusterLogFunc to support // cluster logging within a test bucket pool context -func newTestCluster(server string, logger clusterLogFunc) tbpCluster { - return newTestClusterV2(server, logger) +func newTestCluster(ctx context.Context, server string, logger clusterLogFunc) tbpCluster { + return newTestClusterV2(ctx, server, logger) } // tbpClusterV2 implements the tbpCluster interface for a gocb v2 cluster @@ -48,16 +48,16 @@ type tbpClusterV2 struct { var _ tbpCluster = &tbpClusterV2{} -func newTestClusterV2(server string, logger clusterLogFunc) *tbpClusterV2 { +func newTestClusterV2(ctx context.Context, server string, logger clusterLogFunc) *tbpClusterV2 { tbpCluster := &tbpClusterV2{} tbpCluster.logger = logger - tbpCluster.cluster = initV2Cluster(server) + tbpCluster.cluster = initV2Cluster(ctx, server) tbpCluster.server = server return tbpCluster } // initV2Cluster makes cluster connection. Callers must close. -func initV2Cluster(server string) *gocb.Cluster { +func initV2Cluster(ctx context.Context, server string) *gocb.Cluster { testClusterTimeout := 10 * time.Second spec := BucketSpec{ @@ -68,17 +68,17 @@ func initV2Cluster(server string) *gocb.Cluster { connStr, err := spec.GetGoCBConnString(nil) if err != nil { - FatalfCtx(context.TODO(), "error getting connection string: %v", err) + FatalfCtx(ctx, "error getting connection string: %v", err) } - securityConfig, err := GoCBv2SecurityConfig(&spec.TLSSkipVerify, spec.CACertPath) + securityConfig, err := GoCBv2SecurityConfig(ctx, &spec.TLSSkipVerify, spec.CACertPath) if err != nil { - FatalfCtx(context.TODO(), "Couldn't initialize cluster security config: %v", err) + FatalfCtx(ctx, "Couldn't initialize cluster security config: %v", err) } authenticatorConfig, authErr := GoCBv2Authenticator(TestClusterUsername(), TestClusterPassword(), spec.Certpath, spec.Keypath) if authErr != nil { - FatalfCtx(context.TODO(), "Couldn't initialize cluster authenticator config: %v", authErr) + FatalfCtx(ctx, "Couldn't initialize cluster authenticator config: %v", authErr) } timeoutsConfig := GoCBv2TimeoutsConfig(spec.BucketOpTimeout, StdlibDurationPtr(spec.GetViewQueryTimeout())) @@ -91,12 +91,12 @@ func initV2Cluster(server string) *gocb.Cluster { cluster, err := gocb.Connect(connStr, clusterOptions) if err != nil { - FatalfCtx(context.TODO(), "Couldn't connect to %q: %v", server, err) + FatalfCtx(ctx, "Couldn't connect to %q: %v", server, err) } const clusterReadyTimeout = 90 * time.Second err = cluster.WaitUntilReady(clusterReadyTimeout, nil) if err != nil { - FatalfCtx(context.TODO(), "Cluster not ready after %ds: %v", int(clusterReadyTimeout.Seconds()), err) + FatalfCtx(ctx, "Cluster not ready after %ds: %v", int(clusterReadyTimeout.Seconds()), err) } return cluster } @@ -130,7 +130,7 @@ func (c *tbpClusterV2) getBucketNames() ([]string, error) { return names, nil } -func (c *tbpClusterV2) insertBucket(name string, quotaMB int) error { +func (c *tbpClusterV2) insertBucket(ctx context.Context, name string, quotaMB int) error { settings := gocb.CreateBucketSettings{ BucketSettings: gocb.BucketSettings{ @@ -138,7 +138,7 @@ func (c *tbpClusterV2) insertBucket(name string, quotaMB int) error { RAMQuotaMB: uint64(quotaMB), BucketType: gocb.CouchbaseBucketType, FlushEnabled: true, - NumReplicas: tbpNumReplicas(), + NumReplicas: tbpNumReplicas(ctx), }, } @@ -153,14 +153,14 @@ func (c *tbpClusterV2) removeBucket(name string) error { } // openTestBucket opens the bucket of the given name for the gocb cluster in the given TestBucketPool. -func (c *tbpClusterV2) openTestBucket(testBucketName tbpBucketName, waitUntilReady time.Duration) (Bucket, error) { +func (c *tbpClusterV2) openTestBucket(ctx context.Context, testBucketName tbpBucketName, waitUntilReady time.Duration) (Bucket, error) { - bucketCluster := initV2Cluster(c.server) + bucketCluster := initV2Cluster(ctx, c.server) // bucketSpec := getTestBucketSpec(testBucketName, usingNamedCollections) bucketSpec := getTestBucketSpec(testBucketName) - bucketFromSpec, err := GetGocbV2BucketFromCluster(bucketCluster, bucketSpec, waitUntilReady, false) + bucketFromSpec, err := GetGocbV2BucketFromCluster(ctx, bucketCluster, bucketSpec, waitUntilReady, false) if err != nil { return nil, err } @@ -168,21 +168,21 @@ func (c *tbpClusterV2) openTestBucket(testBucketName tbpBucketName, waitUntilRea return bucketFromSpec, nil } -func (c *tbpClusterV2) close() error { +func (c *tbpClusterV2) close(ctx context.Context) error { // no close operations needed if c.cluster != nil { if err := c.cluster.Close(nil); err != nil { - c.logger(context.Background(), "Couldn't close cluster connection: %v", err) + c.logger(ctx, "Couldn't close cluster connection: %v", err) return err } } return nil } -func (c *tbpClusterV2) getMinClusterCompatVersion() int { +func (c *tbpClusterV2) getMinClusterCompatVersion(ctx context.Context) int { nodesMeta, err := c.cluster.Internal().GetNodesMetadata(nil) if err != nil { - FatalfCtx(context.Background(), "TEST: failed to fetch nodes metadata: %v", err) + FatalfCtx(ctx, "TEST: failed to fetch nodes metadata: %v", err) } if len(nodesMeta) < 1 { panic("invalid NodesMetadata: no nodes") @@ -215,44 +215,3 @@ func (c *tbpClusterV2) supportsMobileRBAC() (bool, error) { } return major >= 7 && minor >= 1, nil } - -// dropAllScopesAndCollections attempts to drop *all* non-_default scopes and collections from the bucket associated with the collection, except those used by the test bucket pool. Intended for test usage only. -func dropAllScopesAndCollections(bucket *gocb.Bucket) error { - cm := bucket.Collections() - scopes, err := cm.GetAllScopes(nil) - if err != nil { - if httpErr, ok := err.(gocb.HTTPError); ok && httpErr.StatusCode == 404 { - return ErrCollectionsUnsupported - } - WarnfCtx(context.TODO(), "Error getting scopes on bucket %s: %v Will retry.", MD(bucket.Name()).Redact(), err) - return err - } - - // For each non-default scope, drop them. - // For each collection within the default scope, drop them. - for _, scope := range scopes { - if scope.Name != DefaultScope && !strings.HasPrefix(scope.Name, tbpScopePrefix) { - scopeName := fmt.Sprintf("scope %s on bucket %s", MD(scope).Redact(), MD(bucket.Name()).Redact()) - TracefCtx(context.TODO(), KeyAll, "Dropping %s", scopeName) - if err := cm.DropScope(scope.Name, nil); err != nil { - WarnfCtx(context.TODO(), "Error dropping %s: %v Will retry.", scopeName, err) - return err - } - continue - } - - // can't delete _default scope - but we can delete the non-_default collections within it - for _, collection := range scope.Collections { - if collection.Name != DefaultCollection && !strings.HasPrefix(collection.Name, tbpCollectionPrefix) { - collectionName := fmt.Sprintf("collection %s in scope %s on bucket %s", MD(collection.Name).Redact(), MD(scope).Redact(), MD(bucket.Name()).Redact()) - TracefCtx(context.TODO(), KeyAll, "Dropping %s", collectionName) - if err := cm.DropCollection(collection, nil); err != nil { - WarnfCtx(context.TODO(), "Error dropping %s: %v Will retry.", collectionName, err) - return err - } - } - } - - } - return nil -} diff --git a/base/rlimit.go b/base/rlimit.go index 9c34b86a20..b6005e68eb 100644 --- a/base/rlimit.go +++ b/base/rlimit.go @@ -28,7 +28,7 @@ import ( // https://github.com/couchbase/sync_gateway/issues/1083 // - Hard limit vs Soft limit // http://unix.stackexchange.com/questions/29577/ulimit-difference-between-hard-and-soft-limits -func SetMaxFileDescriptors(requestedSoftFDLimit uint64) (uint64, error) { +func SetMaxFileDescriptors(ctx context.Context, requestedSoftFDLimit uint64) (uint64, error) { var limits syscall.Rlimit @@ -37,6 +37,7 @@ func SetMaxFileDescriptors(requestedSoftFDLimit uint64) (uint64, error) { } requiresUpdate, recommendedSoftFDLimit := getSoftFDLimit( + ctx, requestedSoftFDLimit, limits, ) @@ -52,7 +53,7 @@ func SetMaxFileDescriptors(requestedSoftFDLimit uint64) (uint64, error) { err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limits) if err == nil { - InfofCtx(context.Background(), KeyAll, "Configured process to allow %d open file descriptors", recommendedSoftFDLimit) + InfofCtx(ctx, KeyAll, "Configured process to allow %d open file descriptors", recommendedSoftFDLimit) } return recommendedSoftFDLimit, err @@ -76,7 +77,7 @@ func SetMaxFileDescriptors(requestedSoftFDLimit uint64) (uint64, error) { // a lower limit than the system limit // 2. Only return a value that is LESS-THAN-OR-EQUAL to the existing hard limit // since trying to set something higher than the hard limit will fail -func getSoftFDLimit(requestedSoftFDLimit uint64, limit syscall.Rlimit) (requiresUpdate bool, recommendedSoftFDLimit uint64) { +func getSoftFDLimit(ctx context.Context, requestedSoftFDLimit uint64, limit syscall.Rlimit) (requiresUpdate bool, recommendedSoftFDLimit uint64) { currentSoftFdLimit := limit.Cur currentHardFdLimit := limit.Max @@ -84,14 +85,14 @@ func getSoftFDLimit(requestedSoftFDLimit uint64, limit syscall.Rlimit) (requires // Is the user requesting something that is less than the existing soft limit? if requestedSoftFDLimit <= currentSoftFdLimit { // yep, and there is no point in doing so, so return false for requiresUpdate. - DebugfCtx(context.Background(), KeyAll, "requestedSoftFDLimit < currentSoftFdLimit (%v <= %v) no action needed", requestedSoftFDLimit, currentSoftFdLimit) + DebugfCtx(ctx, KeyAll, "requestedSoftFDLimit < currentSoftFdLimit (%v <= %v) no action needed", requestedSoftFDLimit, currentSoftFdLimit) return false, currentSoftFdLimit } // Is the user requesting something higher than the existing hard limit? if requestedSoftFDLimit >= currentHardFdLimit { // yes, so just use the hard limit - InfofCtx(context.Background(), KeyAll, "requestedSoftFDLimit >= currentHardFdLimit (%v >= %v) capping at %v", requestedSoftFDLimit, currentHardFdLimit, currentHardFdLimit) + InfofCtx(ctx, KeyAll, "requestedSoftFDLimit >= currentHardFdLimit (%v >= %v) capping at %v", requestedSoftFDLimit, currentHardFdLimit, currentHardFdLimit) return true, currentHardFdLimit } diff --git a/base/rlimit_test.go b/base/rlimit_test.go index c87b3acd6c..40bed853c2 100644 --- a/base/rlimit_test.go +++ b/base/rlimit_test.go @@ -31,8 +31,9 @@ func TestGetSoftFDLimitWithCurrent(t *testing.T) { Cur: currentSoftFdLimit, Max: currentHardFdLimit, } - + ctx := TestCtx(t) requiresUpdate, _ := getSoftFDLimit( + ctx, requestedSoftFDLimit, limit, ) @@ -41,6 +42,7 @@ func TestGetSoftFDLimitWithCurrent(t *testing.T) { limit.Cur = uint64(512) requiresUpdate, softFDLimit := getSoftFDLimit( + ctx, requestedSoftFDLimit, limit, ) @@ -67,27 +69,28 @@ func TestSetMaxFileDescriptors(t *testing.T) { }() // noop - n, err := SetMaxFileDescriptors(0) + ctx := TestCtx(t) + n, err := SetMaxFileDescriptors(ctx, 0) assert.NoError(t, err) assert.Equal(t, 0, int(n)) // noop (current limit < new limit) - n, err = SetMaxFileDescriptors(newLimits.Cur - 1) + n, err = SetMaxFileDescriptors(ctx, newLimits.Cur-1) assert.NoError(t, err) assert.Equal(t, 0, int(n)) // noop (current limit == new limit) - n, err = SetMaxFileDescriptors(newLimits.Cur) + n, err = SetMaxFileDescriptors(ctx, newLimits.Cur) assert.NoError(t, err) assert.Equal(t, 0, int(n)) // increase - n, err = SetMaxFileDescriptors(newLimits.Cur + 2) + n, err = SetMaxFileDescriptors(ctx, newLimits.Cur+2) assert.NoError(t, err) assert.Equal(t, int(newLimits.Cur+2), int(n)) // noop (we don't decrease limits) - n, err = SetMaxFileDescriptors(newLimits.Cur + 1) + n, err = SetMaxFileDescriptors(ctx, newLimits.Cur+1) assert.NoError(t, err) assert.Equal(t, 0, int(n)) } diff --git a/base/rlimit_windows.go b/base/rlimit_windows.go index a75edb6a55..9b8aa5a414 100644 --- a/base/rlimit_windows.go +++ b/base/rlimit_windows.go @@ -10,6 +10,8 @@ licenses/APL2.txt. package base -func SetMaxFileDescriptors(maxFDs uint64) (uint64, error) { +import "context" + +func SetMaxFileDescriptors(_ context.Context, maxFDs uint64) (uint64, error) { return 0, nil } diff --git a/base/sg_cluster_cfg.go b/base/sg_cluster_cfg.go index 75f6775fa1..baab08a996 100644 --- a/base/sg_cluster_cfg.go +++ b/base/sg_cluster_cfg.go @@ -44,11 +44,11 @@ var ErrCfgCasError = &cbgt.CfgCASError{} // // urlStr: single URL or multiple URLs delimited by ';' // bucket: couchbase bucket name -func NewCfgSG(datastore sgbucket.DataStore, keyPrefix string) (*CfgSG, error) { +func NewCfgSG(ctx context.Context, datastore sgbucket.DataStore, keyPrefix string) (*CfgSG, error) { cfgContextID := MD(datastore.GetName()).Redact() + "-cfgSG" // should this inherit DB context? - loggingCtx := LogContextWith(context.Background(), &LogContext{CorrelationID: cfgContextID}) + loggingCtx := LogContextWith(ctx, &LogContext{CorrelationID: cfgContextID}) c := &CfgSG{ datastore: datastore, diff --git a/base/util.go b/base/util.go index d1404ee0f4..3ff097f2fe 100644 --- a/base/util.go +++ b/base/util.go @@ -73,20 +73,20 @@ func NewNonCancelCtxForDatabase(dbName string, dbConsoleLogConfig *DbConsoleLogC } // RedactBasicAuthURLUserAndPassword returns the given string, with a redacted HTTP basic auth component. -func RedactBasicAuthURLUserAndPassword(urlIn string) string { +func RedactBasicAuthURLUserAndPassword(ctx context.Context, urlIn string) string { redactedUrl, err := RedactBasicAuthURL(urlIn, false) if err != nil { - WarnfCtx(context.Background(), "%v", err) + WarnfCtx(ctx, "%v", err) return "" } return redactedUrl } // RedactBasicAuthURLPassword returns the given string, with a redacted HTTP basic auth password component. -func RedactBasicAuthURLPassword(urlIn string) string { +func RedactBasicAuthURLPassword(ctx context.Context, urlIn string) string { redactedUrl, err := RedactBasicAuthURL(urlIn, true) if err != nil { - WarnfCtx(context.Background(), "%v", err) + WarnfCtx(ctx, "%v", err) return "" } return redactedUrl @@ -442,11 +442,7 @@ func (r *RetryTimeoutError) Error() string { return fmt.Sprintf("RetryLoop for %v giving up after %v attempts", r.description, r.attempts) } -func RetryLoop(description string, worker RetryWorker, sleeper RetrySleeper) (error, interface{}) { - return RetryLoopCtx(description, worker, sleeper, context.Background()) -} - -func RetryLoopCtx(description string, worker RetryWorker, sleeper RetrySleeper, ctx context.Context) (error, interface{}) { +func RetryLoop(ctx context.Context, description string, worker RetryWorker, sleeper RetrySleeper) (error, interface{}) { numAttempts := 1 @@ -481,7 +477,7 @@ func RetryLoopCtx(description string, worker RetryWorker, sleeper RetrySleeper, // A version of RetryLoop that returns a strongly typed cas as uint64, to avoid interface conversion overhead for // high throughput operations. -func RetryLoopCas(description string, worker RetryCasWorker, sleeper RetrySleeper) (error, uint64) { +func RetryLoopCas(ctx context.Context, description string, worker RetryCasWorker, sleeper RetrySleeper) (error, uint64) { numAttempts := 1 @@ -498,10 +494,10 @@ func RetryLoopCas(description string, worker RetryCasWorker, sleeper RetrySleepe if err == nil { err = NewRetryTimeoutError(description, numAttempts) } - WarnfCtx(context.Background(), "RetryLoopCas for %v giving up after %v attempts", description, numAttempts) + WarnfCtx(ctx, "RetryLoopCas for %v giving up after %v attempts", description, numAttempts) return err, value } - DebugfCtx(context.Background(), KeyAll, "RetryLoopCas retrying %v after %v ms.", description, sleepMs) + DebugfCtx(ctx, KeyAll, "RetryLoopCas retrying %v after %v ms.", description, sleepMs) <-time.After(time.Millisecond * time.Duration(sleepMs)) @@ -1215,13 +1211,13 @@ func ExpvarFloatVal(val float64) *expvar.Float { } // Convert an expvar.Var to an int64. Return 0 if the expvar var is nil. -func ExpvarVar2Int(expvarVar expvar.Var) int64 { +func ExpvarVar2Int(ctx context.Context, expvarVar expvar.Var) int64 { if expvarVar == nil { return 0 } asInt, ok := expvarVar.(*expvar.Int) if !ok { - WarnfCtx(context.Background(), "ExpvarVar2Int could not convert %v to *expvar.Int", expvarVar) + WarnfCtx(ctx, "ExpvarVar2Int could not convert %v to *expvar.Int", expvarVar) return 0 } return asInt.Value() @@ -1834,8 +1830,8 @@ func AllOrNoneNil(vals ...interface{}) bool { } // WaitForNoError runs the callback until it no longer returns an error. -func WaitForNoError(callback func() error) error { - err, _ := RetryLoop("wait for no error", func() (bool, error, interface{}) { +func WaitForNoError(ctx context.Context, callback func() error) error { + err, _ := RetryLoop(ctx, "wait for no error", func() (bool, error, interface{}) { callbackErr := callback() return callbackErr != nil, callbackErr, nil }, CreateMaxDoublingSleeperFunc(30, 10, 1000)) diff --git a/base/util_test.go b/base/util_test.go index 21648abefb..d7e3c59249 100644 --- a/base/util_test.go +++ b/base/util_test.go @@ -202,7 +202,7 @@ func TestRetryLoop(t *testing.T) { // Kick off retry loop description := fmt.Sprintf("TestRetryLoop") - err, result := RetryLoop(description, worker, sleeper) + err, result := RetryLoop(TestCtx(t), description, worker, sleeper) // We shouldn't get an error, because it will retry a few times and then succeed assert.True(t, err == nil) @@ -624,7 +624,7 @@ func TestRedactBasicAuthURL(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { - assert.Equal(t, test.expected, RedactBasicAuthURLUserAndPassword(test.input)) + assert.Equal(t, test.expected, RedactBasicAuthURLUserAndPassword(TestCtx(t), test.input)) }) } } diff --git a/base/util_testing.go b/base/util_testing.go index 777a16fd4e..8a0afb6c89 100644 --- a/base/util_testing.go +++ b/base/util_testing.go @@ -594,12 +594,12 @@ var GlobalTestLoggingSet = AtomicBool{} // SetUpGlobalTestLogging sets a global log level at runtime by using the SG_TEST_LOG_LEVEL environment variable. // This global level overrides any tests that specify their own test log level with SetUpTestLogging. -func SetUpGlobalTestLogging(m *testing.M) (teardownFn func()) { +func SetUpGlobalTestLogging(ctx context.Context, m *testing.M) (teardownFn func()) { if logLevel := os.Getenv(TestEnvGlobalLogLevel); logLevel != "" { var l LogLevel err := l.UnmarshalText([]byte(logLevel)) if err != nil { - FatalfCtx(context.TODO(), "TEST: Invalid log level used for %q: %s", TestEnvGlobalLogLevel, err) + FatalfCtx(ctx, "TEST: Invalid log level used for %q: %s", TestEnvGlobalLogLevel, err) } caller := GetCallersName(1, false) InfofCtx(context.Background(), KeyAll, "%s: Setup logging: level: %v - keys: %v", caller, logLevel, KeyAll) @@ -745,13 +745,13 @@ func DirExists(filename string) bool { } // WaitForStat will retry for up to 20 seconds until the result of getStatFunc is equal to the expected value. -func WaitForStat(getStatFunc func() int64, expected int64) (int64, bool) { +func WaitForStat(t testing.TB, getStatFunc func() int64, expected int64) (int64, bool) { workerFunc := func() (shouldRetry bool, err error, val interface{}) { val = getStatFunc() return val != expected, nil, val } // wait for up to 20 seconds for the stat to meet the expected value - err, val := RetryLoop("waitForStat retry loop", workerFunc, CreateSleeperFunc(200, 100)) + err, val := RetryLoop(TestCtx(t), "waitForStat retry loop", workerFunc, CreateSleeperFunc(200, 100)) valInt64, ok := val.(int64) return valInt64, err == nil && ok @@ -759,7 +759,7 @@ func WaitForStat(getStatFunc func() int64, expected int64) (int64, bool) { // RequireWaitForStat will retry for up to 20 seconds until the result of getStatFunc is equal to the expected value. func RequireWaitForStat(t testing.TB, getStatFunc func() int64, expected int64) { - val, ok := WaitForStat(getStatFunc, expected) + val, ok := WaitForStat(t, getStatFunc, expected) require.True(t, ok) require.Equal(t, expected, val) } @@ -767,7 +767,7 @@ func RequireWaitForStat(t testing.TB, getStatFunc func() int64, expected int64) // TestRequiresCollections will skip the current test if the Couchbase Server version it is running against does not // support collections. func TestRequiresCollections(t testing.TB) { - if ok, err := GTestBucketPool.canUseNamedCollections(); err != nil { + if ok, err := GTestBucketPool.canUseNamedCollections(TestCtx(t)); err != nil { t.Skipf("Skipping test - collections not supported: %v", err) } else if !ok { t.Skipf("Skipping test - collections not enabled") @@ -799,7 +799,7 @@ func CreateBucketScopesAndCollections(ctx context.Context, bucketSpec BucketSpec un, pw, _ := bucketSpec.Auth.GetCredentials() var rootCAs *x509.CertPool - if tlsConfig := bucketSpec.TLSConfig(); tlsConfig != nil { + if tlsConfig := bucketSpec.TLSConfig(ctx); tlsConfig != nil { rootCAs = tlsConfig.RootCAs } cluster, err := gocb.Connect(bucketSpec.Server, gocb.ClusterOptions{ @@ -831,7 +831,7 @@ func CreateBucketScopesAndCollections(ctx context.Context, bucketSpec BucketSpec return fmt.Errorf("failed to create collection %s in scope %s: %w", collectionName, scopeName, err) } DebugfCtx(ctx, KeySGTest, "Created collection %s.%s", scopeName, collectionName) - if err := WaitForNoError(func() error { + if err := WaitForNoError(ctx, func() error { _, err := cluster.Bucket(bucketSpec.BucketName).Scope(scopeName).Collection(collectionName).Exists("WaitForExists", nil) return err }); err != nil { diff --git a/channels/active_channels.go b/channels/active_channels.go index 272b2600d4..1cc7ea03b8 100644 --- a/channels/active_channels.go +++ b/channels/active_channels.go @@ -42,13 +42,13 @@ func NewActiveChannels(activeChannelCountStat *base.SgwIntStat) *ActiveChannels // Update changed increments/decrements active channel counts based on a set of changed channels. Triggered // when the set of channels being replicated by a given replication changes. -func (ac *ActiveChannels) UpdateChanged(collectionID uint32, changedChannels ChangedKeys) { +func (ac *ActiveChannels) UpdateChanged(ctx context.Context, collectionID uint32, changedChannels ChangedKeys) { ac.lock.Lock() for channelName, isIncrement := range changedChannels { if isIncrement { ac._incr(NewID(channelName, collectionID)) } else { - ac._decr(NewID(channelName, collectionID)) + ac._decr(ctx, NewID(channelName, collectionID)) } } @@ -64,11 +64,11 @@ func (ac *ActiveChannels) IncrChannels(collectionID uint32, timedSet TimedSet) { } } -func (ac *ActiveChannels) DecrChannels(collectionID uint32, timedSet TimedSet) { +func (ac *ActiveChannels) DecrChannels(ctx context.Context, collectionID uint32, timedSet TimedSet) { ac.lock.Lock() defer ac.lock.Unlock() for channelName, _ := range timedSet { - ac._decr(NewID(channelName, collectionID)) + ac._decr(ctx, NewID(channelName, collectionID)) } } @@ -85,9 +85,9 @@ func (ac *ActiveChannels) IncrChannel(channel ID) { ac.lock.Unlock() } -func (ac *ActiveChannels) DecrChannel(channel ID) { +func (ac *ActiveChannels) DecrChannel(ctx context.Context, channel ID) { ac.lock.Lock() - ac._decr(channel) + ac._decr(ctx, channel) ac.lock.Unlock() } @@ -99,10 +99,10 @@ func (ac *ActiveChannels) _incr(channel ID) { ac.channelCounts[channel] = current + 1 } -func (ac *ActiveChannels) _decr(channel ID) { +func (ac *ActiveChannels) _decr(ctx context.Context, channel ID) { current, ok := ac.channelCounts[channel] if !ok { - base.WarnfCtx(context.Background(), "Attempt made to decrement inactive channel %s - will be ignored", base.UD(channel)) + base.WarnfCtx(ctx, "Attempt made to decrement inactive channel %s - will be ignored", base.UD(channel)) return } if current <= 1 { diff --git a/channels/active_channels_test.go b/channels/active_channels_test.go index 388a92c840..f8d020b4d1 100644 --- a/channels/active_channels_test.go +++ b/channels/active_channels_test.go @@ -28,6 +28,7 @@ func TestActiveChannelsConcurrency(t *testing.T) { GHIChan := NewID("GHI", base.DefaultCollectionID) JKLChan := NewID("JKL", base.DefaultCollectionID) MNOChan := NewID("MNO", base.DefaultCollectionID) + ctx := base.TestCtx(t) // Concurrent Incr, Decr for i := 0; i < 50; i++ { wg.Add(1) @@ -40,7 +41,7 @@ func TestActiveChannelsConcurrency(t *testing.T) { inactiveChans := base.SetOf(ABCChan.Name, DEFChan.Name) inactiveChansTimedSet := AtSequence(inactiveChans, seqNo) - ac.DecrChannels(base.DefaultCollectionID, inactiveChansTimedSet) + ac.DecrChannels(ctx, base.DefaultCollectionID, inactiveChansTimedSet) }() } wg.Wait() @@ -57,9 +58,9 @@ func TestActiveChannelsConcurrency(t *testing.T) { go func() { defer wg.Done() changedKeys := ChangedKeys{"ABC": true, "DEF": true, "GHI": false, "MNO": true} - ac.UpdateChanged(base.DefaultCollectionID, changedKeys) + ac.UpdateChanged(ctx, base.DefaultCollectionID, changedKeys) changedKeys = ChangedKeys{"DEF": false} - ac.UpdateChanged(base.DefaultCollectionID, changedKeys) + ac.UpdateChanged(ctx, base.DefaultCollectionID, changedKeys) }() } wg.Wait() diff --git a/channels/channelmapper_test.go b/channels/channelmapper_test.go index 24eaedff7b..3cf7afbb69 100644 --- a/channels/channelmapper_test.go +++ b/channels/channelmapper_test.go @@ -38,14 +38,15 @@ func emptyMetaMap() map[string]interface{} { var noUser = map[string]interface{}{"name": nil, "channels": []string{}} func TestOttoValueToStringArray(t *testing.T) { + ctx := base.TestCtx(t) // Test for https://github.com/robertkrimen/otto/issues/24 value, _ := otto.New().ToValue([]string{"foo", "bar", "baz"}) - strings := ottoValueToStringArray(value) + strings := ottoValueToStringArray(ctx, value) assert.Equal(t, []string{"foo", "bar", "baz"}, strings) // Test for https://issues.couchbase.com/browse/CBG-714 value, _ = otto.New().ToValue([]interface{}{"a", []interface{}{"b", "g"}, "c", 4}) - strings = ottoValueToStringArray(value) + strings = ottoValueToStringArray(ctx, value) assert.Equal(t, []string{"a", "c"}, strings) } diff --git a/channels/main_test.go b/channels/main_test.go index 120c700f8e..f3877481b9 100644 --- a/channels/main_test.go +++ b/channels/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package channels import ( + "context" "os" "testing" @@ -18,9 +19,10 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process // can't use defer because of os.Exit teardownFuncs := make([]func(), 0) - teardownFuncs = append(teardownFuncs, base.SetUpGlobalTestLogging(m)) + teardownFuncs = append(teardownFuncs, base.SetUpGlobalTestLogging(ctx, m)) teardownFuncs = append(teardownFuncs, base.SetUpGlobalTestProfiling(m)) teardownFuncs = append(teardownFuncs, base.SetUpGlobalTestMemoryWatermark(m, 128)) diff --git a/channels/sync_runner.go b/channels/sync_runner.go index 72172aeaeb..063bab0d6a 100644 --- a/channels/sync_runner.go +++ b/channels/sync_runner.go @@ -129,7 +129,7 @@ func NewSyncRunner(ctx context.Context, funcSource string, timeout time.Duration // Implementation of the 'channel()' callback: runner.DefineNativeFunction("channel", func(call otto.FunctionCall) otto.Value { for _, arg := range call.ArgumentList { - if strings := ottoValueToStringArray(arg); strings != nil { + if strings := ottoValueToStringArray(ctx, arg); strings != nil { runner.channels = append(runner.channels, strings...) } } @@ -138,12 +138,12 @@ func NewSyncRunner(ctx context.Context, funcSource string, timeout time.Duration // Implementation of the 'access()' callback: runner.DefineNativeFunction("access", func(call otto.FunctionCall) otto.Value { - return runner.addValueForUser(call.Argument(0), call.Argument(1), runner.access) + return runner.addValueForUser(ctx, call.Argument(0), call.Argument(1), runner.access) }) // Implementation of the 'role()' callback: runner.DefineNativeFunction("role", func(call otto.FunctionCall) otto.Value { - return runner.addValueForUser(call.Argument(0), call.Argument(1), runner.roles) + return runner.addValueForUser(ctx, call.Argument(0), call.Argument(1), runner.roles) }) // Implementation of the 'reject()' callback: @@ -218,10 +218,10 @@ func (runner *SyncRunner) SetFunction(funcSource string) (bool, error) { } // Common implementation of 'access()' and 'role()' callbacks -func (runner *SyncRunner) addValueForUser(user otto.Value, value otto.Value, mapping map[string][]string) otto.Value { - valueStrings := ottoValueToStringArray(value) +func (runner *SyncRunner) addValueForUser(ctx context.Context, user otto.Value, value otto.Value, mapping map[string][]string) otto.Value { + valueStrings := ottoValueToStringArray(ctx, value) if len(valueStrings) > 0 { - for _, name := range ottoValueToStringArray(user) { + for _, name := range ottoValueToStringArray(ctx, user) { mapping[name] = append(mapping[name], valueStrings...) } } @@ -258,13 +258,13 @@ func AccessNameToPrincipalName(accessPrincipalName string) (principalName string } // Converts a JS string or array into a Go string array. -func ottoValueToStringArray(value otto.Value) []string { +func ottoValueToStringArray(ctx context.Context, value otto.Value) []string { nativeValue, _ := value.Export() result, nonStrings := base.ValueToStringArray(nativeValue) if !value.IsNull() && !value.IsUndefined() && nonStrings != nil { - base.WarnfCtx(context.Background(), "Channel names must be string values only. Ignoring non-string channels: %s", base.UD(nonStrings)) + base.WarnfCtx(ctx, "Channel names must be string values only. Ignoring non-string channels: %s", base.UD(nonStrings)) } return result } diff --git a/db/active_replicator.go b/db/active_replicator.go index 17147f0207..527816693c 100644 --- a/db/active_replicator.go +++ b/db/active_replicator.go @@ -168,7 +168,7 @@ func (ar *ActiveReplicator) _onReplicationComplete() { } -func (ar *ActiveReplicator) State() (state string, errorMessage string) { +func (ar *ActiveReplicator) State(ctx context.Context) (state string, errorMessage string) { state = ReplicationStateStopped if ar.Push != nil { @@ -177,7 +177,7 @@ func (ar *ActiveReplicator) State() (state string, errorMessage string) { if ar.Pull != nil { pullState, pullErrorMessage := ar.Pull.getStateWithErrorMessage() - state = combinedState(state, pullState) + state = combinedState(ctx, state, pullState) if pullErrorMessage != "" { errorMessage = pullErrorMessage } @@ -186,12 +186,12 @@ func (ar *ActiveReplicator) State() (state string, errorMessage string) { return state, errorMessage } -func (ar *ActiveReplicator) GetStatus() *ReplicationStatus { +func (ar *ActiveReplicator) GetStatus(ctx context.Context) *ReplicationStatus { status := &ReplicationStatus{ ID: ar.ID, } - status.Status, status.ErrorMessage = ar.State() + status.Status, status.ErrorMessage = ar.State(ctx) if ar.Pull != nil { status.PullReplicationStatus = ar.Pull.GetStatus().PullReplicationStatus @@ -220,7 +220,10 @@ func connect(arc *activeReplicatorCommon, idSuffix string) (blipSender *blip.Sen } bsc = NewBlipSyncContext(arc.ctx, blipContext, arc.config.ActiveDB, blipContext.ID, arc.replicationStats) - bsc.loggingCtx = base.CorrelationIDLogCtx(context.Background(), arc.config.ID+idSuffix) + + bsc.loggingCtx = base.CorrelationIDLogCtx( + arc.config.ActiveDB.AddDatabaseLogContext(base.NewNonCancelCtx().Ctx), + arc.config.ID+idSuffix) // NewBlipSyncContext has already set deltas as disabled/enabled based on config.ActiveDB. // If deltas have been disabled in the replication config, override this value @@ -303,7 +306,7 @@ func base64UserInfo(i *url.Userinfo) string { // - if either replication is in error, return error // - if either replication is running, return running // - if both replications are stopped, return stopped -func combinedState(state1, state2 string) (combinedState string) { +func combinedState(ctx context.Context, state1, state2 string) (combinedState string) { if state1 == "" { return state2 } @@ -327,7 +330,7 @@ func combinedState(state1, state2 string) (combinedState string) { return ReplicationStateReconnecting } - base.InfofCtx(context.Background(), base.KeyReplicate, "Unhandled combination of replication states (%s, %s), returning %s", state1, state2, state1) + base.InfofCtx(ctx, base.KeyReplicate, "Unhandled combination of replication states (%s, %s), returning %s", state1, state2, state1) return state1 } diff --git a/db/active_replicator_checkpointer.go b/db/active_replicator_checkpointer.go index 7fe96be676..df1a86346b 100644 --- a/db/active_replicator_checkpointer.go +++ b/db/active_replicator_checkpointer.go @@ -564,7 +564,7 @@ func (c *Checkpointer) getRemoteCheckpoint() (checkpoint *replicationCheckpoint, CollectionIdx: c.collectionIdx, } - if err := rq.Send(c.blipSender); err != nil { + if err := rq.Send(c.ctx, c.blipSender); err != nil { return &replicationCheckpoint{}, err } @@ -597,7 +597,7 @@ func (c *Checkpointer) setRemoteCheckpoint(checkpoint *replicationCheckpoint) (n rq.RevID = &parentRev } - if err := rq.Send(c.blipSender); err != nil { + if err := rq.Send(c.ctx, c.blipSender); err != nil { return "", err } diff --git a/db/active_replicator_common.go b/db/active_replicator_common.go index 2ff157c07f..2c512cde2d 100644 --- a/db/active_replicator_common.go +++ b/db/active_replicator_common.go @@ -195,7 +195,7 @@ func (a *activeReplicatorCommon) reconnectLoop() { return err != nil, err, nil } - err, _ := base.RetryLoopCtx("replicator reconnect", retryFunc, sleeperFunc, ctx) + err, _ := base.RetryLoop(ctx, "replicator reconnect", retryFunc, sleeperFunc) // release timer associated with context deadline if deadlineCancel != nil { deadlineCancel() diff --git a/db/active_replicator_common_collections.go b/db/active_replicator_common_collections.go index 13cd4c2748..52abf9b2b9 100644 --- a/db/active_replicator_common_collections.go +++ b/db/active_replicator_common_collections.go @@ -167,7 +167,7 @@ func (arc *activeReplicatorCommon) _initCollections() ([]replicationCheckpoint, return nil, err } - collectionContext := newBlipSyncCollectionContext(dbCollection) + collectionContext := newBlipSyncCollectionContext(arc.blipSyncContext.loggingCtx, dbCollection) blipSyncCollectionContexts[i] = collectionContext collectionCheckpoints[i] = *checkpoint diff --git a/db/active_replicator_pull.go b/db/active_replicator_pull.go index 8e4ab5ec3e..a6b9c7e105 100644 --- a/db/active_replicator_pull.go +++ b/db/active_replicator_pull.go @@ -112,7 +112,7 @@ func (apr *ActivePullReplicator) _startPullNonCollection() error { if err != nil { return err } - apr.blipSyncContext.collections.setNonCollectionAware(newBlipSyncCollectionContext(defaultCollection)) + apr.blipSyncContext.collections.setNonCollectionAware(newBlipSyncCollectionContext(apr.ctx, defaultCollection)) if err := apr._initCheckpointer(nil); err != nil { // clean up anything we've opened so far @@ -150,7 +150,7 @@ func (apr *ActivePullReplicator) _subChanges(collectionIdx *int, since string) e Revocations: apr.config.PurgeOnRemoval, CollectionIdx: collectionIdx, } - return subChangesRequest.Send(apr.blipSender) + return subChangesRequest.Send(apr.ctx, apr.blipSender) } // Complete gracefully shuts down a replication, waiting for all in-flight revisions to be processed diff --git a/db/active_replicator_push.go b/db/active_replicator_push.go index 98b6e5eb04..7dac4886d4 100644 --- a/db/active_replicator_push.go +++ b/db/active_replicator_push.go @@ -287,7 +287,7 @@ func (apr *ActivePushReplicator) _startPushNonCollection() error { if err != nil { return err } - apr.blipSyncContext.collections.setNonCollectionAware(newBlipSyncCollectionContext(dbCollection)) + apr.blipSyncContext.collections.setNonCollectionAware(newBlipSyncCollectionContext(apr.ctx, dbCollection)) if err := apr._initCheckpointer(nil); err != nil { // clean up anything we've opened so far @@ -302,12 +302,8 @@ func (apr *ActivePushReplicator) _startPushNonCollection() error { DatabaseCollection: dbCollection, user: apr.config.ActiveDB.user, } - bh := blipHandler{ - BlipSyncContext: apr.blipSyncContext, - db: apr.config.ActiveDB, - collection: dbCollectionWithUser, - serialNumber: apr.blipSyncContext.incrementSerialNumber(), - } + bh := newBlipHandler(apr.ctx, apr.blipSyncContext, apr.config.ActiveDB, apr.blipSyncContext.incrementSerialNumber()) + bh.collection = dbCollectionWithUser var channels base.Set if filteredChannels := apr.config.getFilteredChannels(nil); len(filteredChannels) > 0 { diff --git a/db/active_replicator_push_collections.go b/db/active_replicator_push_collections.go index f38a943d6e..d81ec95651 100644 --- a/db/active_replicator_push_collections.go +++ b/db/active_replicator_push_collections.go @@ -47,13 +47,9 @@ func (apr *ActivePushReplicator) _startPushWithCollections() error { user: apr.config.ActiveDB.user, } - bh := blipHandler{ - BlipSyncContext: apr.blipSyncContext, - db: apr.config.ActiveDB, - collection: dbCollectionWithUser, - collectionIdx: collectionIdx, - serialNumber: apr.blipSyncContext.incrementSerialNumber(), - } + bh := newBlipHandler(apr.ctx, apr.blipSyncContext, apr.config.ActiveDB, apr.blipSyncContext.incrementSerialNumber()) + bh.collection = dbCollectionWithUser + bh.collectionIdx = collectionIdx var channels base.Set if filteredChannels := apr.config.getFilteredChannels(collectionIdx); len(filteredChannels) > 0 { diff --git a/db/attachment.go b/db/attachment.go index fdcd66e62e..abc01db2ef 100644 --- a/db/attachment.go +++ b/db/attachment.go @@ -169,8 +169,8 @@ func (db *DatabaseCollectionWithUser) retrieveAncestorAttachments(ctx context.Co // No non-pruned ancestor is available if commonAncestor := doc.History.findAncestorFromSet(doc.CurrentRev, docHistory); commonAncestor != "" { parentAttachments := make(map[string]interface{}) - commonAncestorGen := int64(genOfRevID(commonAncestor)) - for name, activeAttachment := range GetBodyAttachments(doc.Body()) { + commonAncestorGen := int64(genOfRevID(ctx, commonAncestor)) + for name, activeAttachment := range GetBodyAttachments(doc.Body(ctx)) { if attachmentMeta, ok := activeAttachment.(map[string]interface{}); ok { activeRevpos, ok := base.ToInt64(attachmentMeta["revpos"]) if ok && activeRevpos <= commonAncestorGen { @@ -329,24 +329,24 @@ func GetAttachmentVersion(meta map[string]interface{}) (int, bool) { } // GenerateProofOfAttachment returns a nonce and proof for an attachment body. -func GenerateProofOfAttachment(attachmentData []byte) (nonce []byte, proof string, err error) { +func GenerateProofOfAttachment(ctx context.Context, attachmentData []byte) (nonce []byte, proof string, err error) { nonce = make([]byte, 20) if _, err := rand.Read(nonce); err != nil { return nil, "", base.HTTPErrorf(http.StatusInternalServerError, fmt.Sprintf("Failed to generate random data: %s", err)) } - proof = ProveAttachment(attachmentData, nonce) - base.TracefCtx(context.Background(), base.KeyCRUD, "Generated nonce %v and proof %q for attachment: %v", nonce, proof, attachmentData) + proof = ProveAttachment(ctx, attachmentData, nonce) + base.TracefCtx(ctx, base.KeyCRUD, "Generated nonce %v and proof %q for attachment: %v", nonce, proof, attachmentData) return nonce, proof, nil } // ProveAttachment returns the proof for an attachment body and nonce pair. -func ProveAttachment(attachmentData, nonce []byte) (proof string) { +func ProveAttachment(ctx context.Context, attachmentData, nonce []byte) (proof string) { d := sha1.New() d.Write([]byte{byte(len(nonce))}) d.Write(nonce) d.Write(attachmentData) proof = "sha1-" + base64.StdEncoding.EncodeToString(d.Sum(nil)) - base.TracefCtx(context.Background(), base.KeyCRUD, "Generated proof %q using nonce %v for attachment: %v", proof, nonce, attachmentData) + base.TracefCtx(ctx, base.KeyCRUD, "Generated proof %q using nonce %v for attachment: %v", proof, nonce, attachmentData) return proof } diff --git a/db/attachment_compaction.go b/db/attachment_compaction.go index cd5fb914bf..d5e4ae4f64 100644 --- a/db/attachment_compaction.go +++ b/db/attachment_compaction.go @@ -143,7 +143,7 @@ func attachmentCompactMarkPhase(ctx context.Context, dataStore base.DataStore, c return 0, nil, "", err } - dcpClient, err := base.NewDCPClient(dcpFeedKey, callback, *clientOptions, bucket) + dcpClient, err := base.NewDCPClient(ctx, dcpFeedKey, callback, *clientOptions, bucket) if err != nil { base.WarnfCtx(ctx, "[%s] Failed to create attachment compaction DCP client! %v", compactionLoggingID, err) return 0, nil, "", err @@ -369,7 +369,7 @@ func attachmentCompactSweepPhase(ctx context.Context, dataStore base.DataStore, } base.InfofCtx(ctx, base.KeyAll, "[%s] Starting DCP feed %q for sweep phase of attachment compaction", compactionLoggingID, dcpFeedKey) - dcpClient, err := base.NewDCPClient(dcpFeedKey, callback, *clientOptions, bucket) + dcpClient, err := base.NewDCPClient(ctx, dcpFeedKey, callback, *clientOptions, bucket) if err != nil { base.WarnfCtx(ctx, "[%s] Failed to create attachment compaction DCP client! %v", compactionLoggingID, err) return 0, err @@ -506,7 +506,7 @@ func attachmentCompactCleanupPhase(ctx context.Context, dataStore base.DataStore return "", err } - dcpClient, err := base.NewDCPClient(dcpFeedKey, callback, *clientOptions, bucket) + dcpClient, err := base.NewDCPClient(ctx, dcpFeedKey, callback, *clientOptions, bucket) if err != nil { base.WarnfCtx(ctx, "[%s] Failed to create attachment compaction DCP client! %v", compactionLoggingID, err) return "", err diff --git a/db/attachment_compaction_test.go b/db/attachment_compaction_test.go index be96483a22..92712cce21 100644 --- a/db/attachment_compaction_test.go +++ b/db/attachment_compaction_test.go @@ -278,7 +278,7 @@ func TestAttachmentCleanupRollback(t *testing.T) { dcpFeedKey := GenerateCompactionDCPStreamName(t.Name(), CleanupPhase) clientOptions, err := getCompactionDCPClientOptions(collectionID, testDb.Options.GroupID, testDb.MetadataKeys.DCPCheckpointPrefix(testDb.Options.GroupID)) require.NoError(t, err) - dcpClient, err := base.NewDCPClient(dcpFeedKey, nil, *clientOptions, bucket) + dcpClient, err := base.NewDCPClient(ctx, dcpFeedKey, nil, *clientOptions, bucket) require.NoError(t, err) // alter dcp metadata to feed into the compaction manager @@ -294,7 +294,7 @@ func TestAttachmentCleanupRollback(t *testing.T) { err = testDb.AttachmentCompactionManager.Process.Run(ctx, map[string]interface{}{"database": testDb}, testDb.AttachmentCompactionManager.UpdateStatusClusterAware, terminator) require.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDb.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -440,7 +440,7 @@ func TestAttachmentCompactionRunTwice(t *testing.T) { err = testDB2.AttachmentCompactionManager.Start(ctx2, map[string]interface{}{"database": testDB2, "dryRun": true}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB2.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -465,7 +465,7 @@ func TestAttachmentCompactionRunTwice(t *testing.T) { err = testDB2.AttachmentCompactionManager.Start(ctx2, map[string]interface{}{"database": testDB2, "dryRun": false}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB2.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -491,7 +491,7 @@ func TestAttachmentCompactionRunTwice(t *testing.T) { err = testDB1.AttachmentCompactionManager.Start(ctx1, map[string]interface{}{"database": testDB1}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB1.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -510,7 +510,7 @@ func TestAttachmentCompactionRunTwice(t *testing.T) { err = testDB1.AttachmentCompactionManager.Start(ctx1, map[string]interface{}{"database": testDB1}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB1.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -590,7 +590,7 @@ func TestAttachmentCompactionStopImmediateStart(t *testing.T) { err = testDB2.AttachmentCompactionManager.Start(ctx2, map[string]interface{}{"database": testDB2, "dryRun": true}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB2.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -615,7 +615,7 @@ func TestAttachmentCompactionStopImmediateStart(t *testing.T) { err = testDB2.AttachmentCompactionManager.Start(ctx2, map[string]interface{}{"database": testDB2, "dryRun": false}) assert.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status AttachmentManagerResponse rawStatus, err := testDB2.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) @@ -676,7 +676,7 @@ func TestAttachmentProcessError(t *testing.T) { assert.NoError(t, err) var status AttachmentManagerResponse - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { rawStatus, err := testDB1.AttachmentCompactionManager.GetStatus() assert.NoError(t, err) err = base.JSONUnmarshal(rawStatus, &status) @@ -716,7 +716,7 @@ func TestAttachmentDifferentVBUUIDsBetweenPhases(t *testing.T) { assert.Contains(t, err.Error(), "error opening stream for vb 0: VbUUID mismatch when failOnRollback set") } -func WaitForConditionWithOptions(successFunc func() bool, maxNumAttempts, timeToSleepMs int) error { +func WaitForConditionWithOptions(t testing.TB, successFunc func() bool, maxNumAttempts, timeToSleepMs int) error { waitForSuccess := func() (shouldRetry bool, err error, value interface{}) { if successFunc() { return false, nil, nil @@ -725,7 +725,7 @@ func WaitForConditionWithOptions(successFunc func() bool, maxNumAttempts, timeTo } sleeper := base.CreateSleeperFunc(maxNumAttempts, timeToSleepMs) - err, _ := base.RetryLoop("Wait for condition options", waitForSuccess, sleeper) + err, _ := base.RetryLoop(base.TestCtx(t), "Wait for condition options", waitForSuccess, sleeper) if err != nil { return err } @@ -1000,11 +1000,11 @@ func TestAttachmentCompactIncorrectStat(t *testing.T) { // The timeToSleepMs here is low to ensure that this retry loop finishes after the mark starts, but before it has time to finish timeToSleepMs = 10 ) - err, _ := base.RetryLoop("wait for marking to start", statAboveZeroRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) + err, _ := base.RetryLoop(ctx, "wait for marking to start", statAboveZeroRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) require.NoError(t, err) terminator.Close() // Terminate mark function - err, _ = base.RetryLoop("wait for marking function to return", compactionFuncReturnedRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) + err, _ = base.RetryLoop(ctx, "wait for marking function to return", compactionFuncReturnedRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) require.NoError(t, err) // Allow time for timing issue to be hit where stat increments when it shouldn't time.Sleep(time.Second * 1) @@ -1024,11 +1024,11 @@ func TestAttachmentCompactIncorrectStat(t *testing.T) { }() // The timeToSleepMs here is low to ensure that this retry loop finishes after the sweep starts, but before it has time to finish - err, _ = base.RetryLoop("wait for sweeping to start", statAboveZeroRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) + err, _ = base.RetryLoop(ctx, "wait for sweeping to start", statAboveZeroRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) require.NoError(t, err) terminator.Close() // Terminate sweep function - err, _ = base.RetryLoop("wait for sweeping function to return", compactionFuncReturnedRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) + err, _ = base.RetryLoop(ctx, "wait for sweeping function to return", compactionFuncReturnedRetryFunc, base.CreateSleeperFunc(maxAttempts, timeToSleepMs)) require.NoError(t, err) // Allow time for timing issue to be hit where stat increments when it shouldn't time.Sleep(time.Second * 1) diff --git a/db/attachment_test.go b/db/attachment_test.go index 381a45c930..74345aa20a 100644 --- a/db/attachment_test.go +++ b/db/attachment_test.go @@ -491,14 +491,14 @@ func TestGenerateProofOfAttachment(t *testing.T) { base.SetUpTestLogging(t, base.LevelDebug, base.KeyAll) attData := []byte(`hello world`) - - nonce, proof1, err := GenerateProofOfAttachment(attData) + ctx := base.TestCtx(t) + nonce, proof1, err := GenerateProofOfAttachment(ctx, attData) require.NoError(t, err) assert.True(t, len(nonce) >= 20, "nonce should be at least 20 bytes") assert.NotEmpty(t, proof1) assert.True(t, strings.HasPrefix(proof1, "sha1-")) - proof2 := ProveAttachment(attData, nonce) + proof2 := ProveAttachment(ctx, attData, nonce) assert.NotEmpty(t, proof1, "") assert.True(t, strings.HasPrefix(proof1, "sha1-")) @@ -847,7 +847,7 @@ func TestMigrateBodyAttachments(t *testing.T) { require.NoError(t, err) // latest rev was 3-a when we called GetActive, make sure that hasn't changed. - gen, _ := ParseRevID(rev.RevID) + gen, _ := ParseRevID(ctx, rev.RevID) assert.Equal(t, 3, gen) // read-only operations don't "upgrade" the metadata, but it should still transform it on-demand before returning. @@ -878,7 +878,7 @@ func TestMigrateBodyAttachments(t *testing.T) { require.NoError(t, err) // latest rev was 3-a when we called Get, make sure that hasn't changed. - gen, _ := ParseRevID(rev.RevID) + gen, _ := ParseRevID(ctx, rev.RevID) assert.Equal(t, 3, gen) // read-only operations don't "upgrade" the metadata, but it should still transform it on-demand before returning. @@ -921,7 +921,7 @@ func TestMigrateBodyAttachments(t *testing.T) { newRevID, _, err := collection.Put(ctx, docKey, newBody) require.NoError(t, err) - gen, _ := ParseRevID(newRevID) + gen, _ := ParseRevID(ctx, newRevID) assert.Equal(t, 4, gen) // Verify attachments are in syncData returned from GetRev @@ -979,14 +979,14 @@ func TestMigrateBodyAttachments(t *testing.T) { newRevID, _, err := collection.Put(ctx, docKey, newBody) require.NoError(t, err) - gen, _ := ParseRevID(newRevID) + gen, _ := ParseRevID(ctx, newRevID) assert.Equal(t, 4, gen) // Verify attachments are now present via GetRev rev, err = collection.GetRev(ctx, docKey, newRevID, true, nil) require.NoError(t, err) - gen, _ = ParseRevID(rev.RevID) + gen, _ = ParseRevID(ctx, rev.RevID) assert.Equal(t, 4, gen) assert.Len(t, rev.Attachments, 2, "expecting 2 attachments returned in rev") diff --git a/db/background_mgr.go b/db/background_mgr.go index c9949624cb..1199b0a955 100644 --- a/db/background_mgr.go +++ b/db/background_mgr.go @@ -95,7 +95,7 @@ type BackgroundManagerProcessI interface { ResetStatus() } -type updateStatusCallbackFunc func() error +type updateStatusCallbackFunc func(ctx context.Context) error // GetName returns name of the background manager func (b *BackgroundManager) GetName() string { @@ -103,7 +103,7 @@ func (b *BackgroundManager) GetName() string { } func (b *BackgroundManager) Start(ctx context.Context, options map[string]interface{}) error { - err := b.markStart() + err := b.markStart(ctx) if err != nil { return err } @@ -132,7 +132,7 @@ func (b *BackgroundManager) Start(ctx context.Context, options map[string]interf for { select { case <-ticker.C: - err := b.UpdateStatusClusterAware() + err := b.UpdateStatusClusterAware(ctx) if err != nil { base.WarnfCtx(ctx, "Failed to update background manager status: %v", err) } @@ -164,7 +164,7 @@ func (b *BackgroundManager) Start(ctx context.Context, options map[string]interf // Once our background process run has completed we should update the completed status and delete the heartbeat // doc if b.isClusterAware() { - err := b.UpdateStatusClusterAware() + err := b.UpdateStatusClusterAware(ctx) if err != nil { base.WarnfCtx(ctx, "Failed to update background manager status: %v", err) } @@ -176,7 +176,7 @@ func (b *BackgroundManager) Start(ctx context.Context, options map[string]interf }() if b.isClusterAware() { - err := b.UpdateStatusClusterAware() + err := b.UpdateStatusClusterAware(ctx) if err != nil { base.ErrorfCtx(ctx, "Failed to update background manager status: %v", err) } @@ -185,7 +185,7 @@ func (b *BackgroundManager) Start(ctx context.Context, options map[string]interf return nil } -func (b *BackgroundManager) markStart() error { +func (b *BackgroundManager) markStart(ctx context.Context) error { b.lock.Lock() defer b.lock.Unlock() @@ -213,9 +213,9 @@ func (b *BackgroundManager) markStart() error { for { select { case <-ticker.C: - err = b.UpdateHeartbeatDocClusterAware() + err = b.UpdateHeartbeatDocClusterAware(ctx) if err != nil { - base.ErrorfCtx(context.TODO(), "Failed to update expiry on heartbeat doc: %v", err) + base.ErrorfCtx(ctx, "Failed to update expiry on heartbeat doc: %v", err) b.SetError(err) } case <-terminator.Done(): @@ -446,11 +446,11 @@ func (b *BackgroundManager) SetError(err error) { // UpdateStatusClusterAware gets the current local status from the running process and updates the status document in // the bucket. Implements a retry. Used for Cluster Aware operations -func (b *BackgroundManager) UpdateStatusClusterAware() error { +func (b *BackgroundManager) UpdateStatusClusterAware(ctx context.Context) error { if b.clusterAwareOptions == nil { return nil } - err, _ := base.RetryLoop("UpdateStatusClusterAware", func() (shouldRetry bool, err error, value interface{}) { + err, _ := base.RetryLoop(ctx, "UpdateStatusClusterAware", func() (shouldRetry bool, err error, value interface{}) { status, metadata, err := b.getStatusLocal() if err != nil { return true, err, nil @@ -477,7 +477,7 @@ type HeartbeatDoc struct { // UpdateHeartbeatDocClusterAware simply performs a touch operation on the heartbeat document to update its expiry. // Implements a retry. Used for Cluster Aware operations -func (b *BackgroundManager) UpdateHeartbeatDocClusterAware() error { +func (b *BackgroundManager) UpdateHeartbeatDocClusterAware(ctx context.Context) error { statusRaw, _, err := b.clusterAwareOptions.metadataStore.GetAndTouchRaw(b.clusterAwareOptions.HeartbeatDocID(), BackgroundManagerHeartbeatExpirySecs) if err != nil { // If we get an error but the error is doc not found and terminator closed it means we have terminated the @@ -505,7 +505,7 @@ func (b *BackgroundManager) UpdateHeartbeatDocClusterAware() error { if status.ShouldStop { err = b.Stop() if err != nil { - base.WarnfCtx(context.TODO(), "Failed to stop process %q: %v", b.clusterAwareOptions.processSuffix, err) + base.WarnfCtx(ctx, "Failed to stop process %q: %v", b.clusterAwareOptions.processSuffix, err) } } diff --git a/db/background_mgr_attachment_compaction.go b/db/background_mgr_attachment_compaction.go index de28b8a0e8..acf17e3cb7 100644 --- a/db/background_mgr_attachment_compaction.go +++ b/db/background_mgr_attachment_compaction.go @@ -113,9 +113,9 @@ func (a *AttachmentCompactionManager) PurgeDCPMetadata(ctx context.Context, data return err } - metadata := base.NewDCPMetadataCS(datastore, numVbuckets, base.DefaultNumWorkers, metadataKeyPrefix) + metadata := base.NewDCPMetadataCS(ctx, datastore, numVbuckets, base.DefaultNumWorkers, metadataKeyPrefix) base.InfofCtx(ctx, base.KeyDCP, "purging persisted dcp metadata for attachment compaction run %s", a.CompactID) - metadata.Purge(base.DefaultNumWorkers) + metadata.Purge(ctx, base.DefaultNumWorkers) return nil } @@ -132,7 +132,7 @@ func (a *AttachmentCompactionManager) Run(ctx context.Context, options map[strin var metadataKeyPrefix string persistClusterStatus := func() { - err := persistClusterStatusCallback() + err := persistClusterStatusCallback(ctx) if err != nil { base.WarnfCtx(ctx, "Failed to persist cluster status on-demand following completion of phase: %v", err) } @@ -157,7 +157,7 @@ func (a *AttachmentCompactionManager) Run(ctx context.Context, options map[strin return shouldRetry, err, nil } // retry loop for handling a rollback during mark phase of compaction process - err, _ = base.RetryLoop("attachmentCompactMarkPhase", worker, base.CreateMaxDoublingSleeperFunc(25, 100, 10000)) + err, _ = base.RetryLoop(ctx, "attachmentCompactMarkPhase", worker, base.CreateMaxDoublingSleeperFunc(25, 100, 10000)) if err != nil || terminator.IsClosed() { if errors.As(err, &rollbackErr) || errors.Is(err, base.ErrVbUUIDMismatch) { // log warning to show we hit max number of retries @@ -185,7 +185,7 @@ func (a *AttachmentCompactionManager) Run(ctx context.Context, options map[strin return shouldRetry, err, nil } // retry loop for handling a rollback during mark phase of compaction process - err, _ = base.RetryLoop("attachmentCompactCleanupPhase", worker, base.CreateMaxDoublingSleeperFunc(25, 100, 10000)) + err, _ = base.RetryLoop(ctx, "attachmentCompactCleanupPhase", worker, base.CreateMaxDoublingSleeperFunc(25, 100, 10000)) if err != nil || terminator.IsClosed() { if errors.As(err, &rollbackErr) || errors.Is(err, base.ErrVbUUIDMismatch) { // log warning to show we hit max number of retries diff --git a/db/background_mgr_resync.go b/db/background_mgr_resync.go index 469bd2dd85..2e2baf86e1 100644 --- a/db/background_mgr_resync.go +++ b/db/background_mgr_resync.go @@ -51,7 +51,7 @@ func (r *ResyncManager) Run(ctx context.Context, options map[string]interface{}, resyncCollections := options["collections"].(ResyncCollections) persistClusterStatus := func() { - err := persistClusterStatusCallback() + err := persistClusterStatusCallback(ctx) if err != nil { base.WarnfCtx(ctx, "Failed to persist cluster status on-demand for resync operation: %v", err) } diff --git a/db/background_mgr_resync_dcp.go b/db/background_mgr_resync_dcp.go index 4503ed0c26..b0dd0d76d5 100644 --- a/db/background_mgr_resync_dcp.go +++ b/db/background_mgr_resync_dcp.go @@ -96,7 +96,7 @@ func (r *ResyncManagerDCP) Run(ctx context.Context, options map[string]interface resyncLoggingID := "Resync: " + r.ResyncID persistClusterStatus := func() { - err := persistClusterStatusCallback() + err := persistClusterStatusCallback(ctx) if err != nil { base.WarnfCtx(ctx, "[%s] Failed to persist cluster status on-demand for resync operation: %v", resyncLoggingID, err) } @@ -160,7 +160,7 @@ func (r *ResyncManagerDCP) Run(ctx context.Context, options map[string]interface clientOptions := getReSyncDCPClientOptions(collectionIDs, db.Options.GroupID, db.MetadataKeys.DCPCheckpointPrefix(db.Options.GroupID)) dcpFeedKey := generateResyncDCPStreamName(r.ResyncID) - dcpClient, err := base.NewDCPClient(dcpFeedKey, callback, *clientOptions, bucket) + dcpClient, err := base.NewDCPClient(ctx, dcpFeedKey, callback, *clientOptions, bucket) if err != nil { base.WarnfCtx(ctx, "[%s] Failed to create resync DCP client! %v", resyncLoggingID, err) return err diff --git a/db/background_mgr_resync_dcp_test.go b/db/background_mgr_resync_dcp_test.go index 9367cb9f31..5ae3d3ffee 100644 --- a/db/background_mgr_resync_dcp_test.go +++ b/db/background_mgr_resync_dcp_test.go @@ -124,7 +124,7 @@ func TestResyncDCPInit(t *testing.T) { require.NoError(t, err) } - err = resycMgr.Process.Init(context.TODO(), options, clusterData) + err = resycMgr.Process.Init(ctx, options, clusterData) require.NoError(t, err) response := getResyncStats(resycMgr.Process) @@ -170,7 +170,7 @@ func TestResyncManagerDCPStopInMidWay(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { stats := getResyncStats(resycMgr.Process) if stats.DocsProcessed > 300 { err = resycMgr.Stop() @@ -182,7 +182,7 @@ func TestResyncManagerDCPStopInMidWay(t *testing.T) { require.NoError(t, err) }() - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resycMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) @@ -224,7 +224,7 @@ func TestResyncManagerDCPStart(t *testing.T) { err := resyncMgr.Start(ctx, options) require.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resyncMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) @@ -259,7 +259,7 @@ func TestResyncManagerDCPStart(t *testing.T) { err := resyncMgr.Start(ctx, options) require.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resyncMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) @@ -306,7 +306,7 @@ func TestResyncManagerDCPRunTwice(t *testing.T) { // Attempt to Start running process go func() { defer wg.Done() - err := WaitForConditionWithOptions(func() bool { + err := WaitForConditionWithOptions(t, func() bool { stats := getResyncStats(resycMgr.Process) return stats.DocsProcessed > 100 }, 100, 200) @@ -317,7 +317,7 @@ func TestResyncManagerDCPRunTwice(t *testing.T) { assert.Contains(t, err.Error(), "Process already running") }() - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resycMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) @@ -375,7 +375,7 @@ func TestResycnManagerDCPResumeStoppedProcess(t *testing.T) { } }() - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resycMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) @@ -391,7 +391,7 @@ func TestResycnManagerDCPResumeStoppedProcess(t *testing.T) { err = resycMgr.Start(ctx, options) require.NoError(t, err) - err = WaitForConditionWithOptions(func() bool { + err = WaitForConditionWithOptions(t, func() bool { var status BackgroundManagerStatus rawStatus, _ := resycMgr.GetStatus() _ = json.Unmarshal(rawStatus, &status) diff --git a/db/blip_collection_context.go b/db/blip_collection_context.go index 0345026056..9699bc8db5 100644 --- a/db/blip_collection_context.go +++ b/db/blip_collection_context.go @@ -49,12 +49,12 @@ type blipCollections struct { const kMaxPendingInsertions = 1000 // newBlipSyncCollection constructs a context to hold all blip data for a given collection. -func newBlipSyncCollectionContext(dbCollection *DatabaseCollection) *blipSyncCollectionContext { +func newBlipSyncCollectionContext(ctx context.Context, dbCollection *DatabaseCollection) *blipSyncCollectionContext { c := &blipSyncCollectionContext{ dbCollection: dbCollection, pendingInsertions: base.Set{}, } - c.changesCtx, c.changesCtxCancel = context.WithCancel(context.Background()) + c.changesCtx, c.changesCtxCancel = context.WithCancel(base.KeyspaceLogCtx(ctx, dbCollection.bucketName(), dbCollection.ScopeName, dbCollection.Name)) return c } diff --git a/db/blip_handler.go b/db/blip_handler.go index a625bf7886..563f0d3ff5 100644 --- a/db/blip_handler.go +++ b/db/blip_handler.go @@ -64,6 +64,15 @@ type blipHandler struct { serialNumber uint64 // This blip handler's serial number to differentiate logs w/ other handlers } +func newBlipHandler(ctx context.Context, bc *BlipSyncContext, db *Database, serialNumber uint64) *blipHandler { + return &blipHandler{ + BlipSyncContext: bc, + db: db, + loggingCtx: ctx, + serialNumber: serialNumber, + } +} + // BlipSyncContextClientType represents whether to replicate to another Sync Gateway or Couchbase Lite type BLIPSyncContextClientType string @@ -155,7 +164,7 @@ func collectionBlipHandler(next blipHandlerFunc) blipHandlerFunc { } bh.collectionCtx, err = bh.collections.get(nil) if err != nil { - bh.collections.setNonCollectionAware(newBlipSyncCollectionContext(bh.collection.DatabaseCollection)) + bh.collections.setNonCollectionAware(newBlipSyncCollectionContext(bh.loggingCtx, bh.collection.DatabaseCollection)) bh.collectionCtx, _ = bh.collections.get(nil) } return next(bh, bm) @@ -242,7 +251,7 @@ func (bh *blipHandler) handleSetCheckpoint(rq *blip.Message) error { func (bh *blipHandler) handleSubChanges(rq *blip.Message) error { defaultSince := CreateZeroSinceValue() latestSeq := func() (SequenceID, error) { - seq, err := bh.collection.LastSequence() + seq, err := bh.collection.LastSequence(bh.loggingCtx) return SequenceID{Seq: seq}, err } subChangesParams, err := NewSubChangesParams(bh.loggingCtx, rq, defaultSince, latestSeq, ParseJSONSequenceID) @@ -264,7 +273,7 @@ func (bh *blipHandler) handleSubChanges(rq *blip.Message) error { // Create ctx if it has been cancelled if collectionCtx.changesCtx.Err() != nil { - collectionCtx.changesCtx, collectionCtx.changesCtxCancel = context.WithCancel(context.Background()) + collectionCtx.changesCtx, collectionCtx.changesCtxCancel = context.WithCancel(bh.loggingCtx) } if len(subChangesParams.docIDs()) > 0 && subChangesParams.continuous() { @@ -496,7 +505,7 @@ func (bh *blipHandler) sendChanges(sender *blip.Sender, opts *sendChangesOptions if bh.db.User() != nil { user = bh.db.User().Name() } - bh.db.DatabaseContext.NotifyTerminatedChanges(user) + bh.db.DatabaseContext.NotifyTerminatedChanges(bh.loggingCtx, user) } return !forceClose @@ -688,7 +697,7 @@ func (bh *blipHandler) handleChanges(rq *blip.Message) error { // already have this rev, tell the peer to skip sending it output.Write([]byte("0")) if collectionCtx.sgr2PullAlreadyKnownSeqsCallback != nil { - seq, err := ParseJSONSequenceID(seqStr(change[0])) + seq, err := ParseJSONSequenceID(seqStr(bh.loggingCtx, change[0])) if err != nil { base.WarnfCtx(bh.loggingCtx, "Unable to parse known sequence %q for %q / %q: %v", change[0], base.UD(docID), revID, err) } else { @@ -710,7 +719,7 @@ func (bh *blipHandler) handleChanges(rq *blip.Message) error { // skip parsing seqno if we're not going to use it (no callback defined) if collectionCtx.sgr2PullAddExpectedSeqsCallback != nil { - seq, err := ParseJSONSequenceID(seqStr(change[0])) + seq, err := ParseJSONSequenceID(seqStr(bh.loggingCtx, change[0])) if err != nil { // We've already asked for the doc/rev for the sequence so assume we're going to receive it... Just log this and carry on base.WarnfCtx(bh.loggingCtx, "Unable to parse expected sequence %q for %q / %q: %v", change[0], base.UD(docID), revID, err) @@ -1009,7 +1018,7 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err } deltaSrcMap := map[string]interface{}(deltaSrcBody) - err = base.Patch(&deltaSrcMap, newDoc.Body()) + err = base.Patch(&deltaSrcMap, newDoc.Body(bh.loggingCtx)) // err should only ever be a FleeceDeltaError here - but to be defensive, handle other errors too (e.g. somehow reaching this code in a CE build) if err != nil { // Something went wrong in the diffing library. We want to know about this! @@ -1022,14 +1031,14 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err stats.deltaRecvCount.Add(1) } - err = validateBlipBody(bodyBytes, newDoc) + err = validateBlipBody(bh.loggingCtx, bodyBytes, newDoc) if err != nil { return err } // Handle and pull out expiry if bytes.Contains(bodyBytes, []byte(BodyExpiry)) { - body := newDoc.Body() + body := newDoc.Body(bh.loggingCtx) expiry, err := body.ExtractExpiry() if err != nil { return base.HTTPErrorf(http.StatusBadRequest, "Invalid expiry: %v", err) @@ -1060,7 +1069,7 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err // Pull out attachments if injectedAttachmentsForDelta || bytes.Contains(bodyBytes, []byte(BodyAttachments)) { - body := newDoc.Body() + body := newDoc.Body(bh.loggingCtx) var currentBucketDoc *Document @@ -1073,12 +1082,12 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err // Otherwise we'll have to go as far back as we can in the doc history and choose the last entry in there. if err == nil { commonAncestor := currentDoc.History.findAncestorFromSet(currentDoc.CurrentRev, history) - minRevpos, _ = ParseRevID(commonAncestor) + minRevpos, _ = ParseRevID(bh.loggingCtx, commonAncestor) minRevpos++ rawBucketDoc = rawDoc currentBucketDoc = currentDoc } else { - minRevpos, _ = ParseRevID(history[len(history)-1]) + minRevpos, _ = ParseRevID(bh.loggingCtx, history[len(history)-1]) } } @@ -1096,7 +1105,7 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err if !ok { // If we don't have this attachment already, ensure incoming revpos is greater than minRevPos, otherwise // update to ensure it's fetched and uploaded - bodyAtts[name].(map[string]interface{})["revpos"], _ = ParseRevID(revID) + bodyAtts[name].(map[string]interface{})["revpos"], _ = ParseRevID(bh.loggingCtx, revID) continue } @@ -1136,7 +1145,7 @@ func (bh *blipHandler) processRev(rq *blip.Message, stats *processRevStats) (err // digest is different we need to override the revpos and set it to the current revision to ensure // the attachment is requested and stored if int(incomingAttachmentRevpos) <= minRevpos && currentAttachmentDigest != incomingAttachmentDigest { - bodyAtts[name].(map[string]interface{})["revpos"], _ = ParseRevID(revID) + bodyAtts[name].(map[string]interface{})["revpos"], _ = ParseRevID(bh.loggingCtx, revID) } } @@ -1232,7 +1241,7 @@ func (bh *blipHandler) handleProveAttachment(rq *blip.Message) error { return base.HTTPErrorf(http.StatusInternalServerError, fmt.Sprintf("Error getting client attachment: %v", err)) } - proof := ProveAttachment(attData, nonce) + proof := ProveAttachment(bh.loggingCtx, attData, nonce) resp := rq.Response() resp.SetBody([]byte(proof)) @@ -1353,7 +1362,7 @@ func (bh *blipHandler) sendGetAttachment(sender *blip.Sender, docID string, name // This is to prevent clients from creating a doc with a digest for an attachment they otherwise can't access, in order to download it. func (bh *blipHandler) sendProveAttachment(sender *blip.Sender, docID, name, digest string, knownData []byte) error { base.DebugfCtx(bh.loggingCtx, base.KeySync, " Verifying attachment %q for doc %s (digest %s)", base.UD(name), base.UD(docID), digest) - nonce, proof, err := GenerateProofOfAttachment(knownData) + nonce, proof, err := GenerateProofOfAttachment(bh.loggingCtx, knownData) if err != nil { return err } diff --git a/db/blip_handler_collections.go b/db/blip_handler_collections.go index 94c96efbff..9604737282 100644 --- a/db/blip_handler_collections.go +++ b/db/blip_handler_collections.go @@ -80,7 +80,7 @@ func (bh *blipHandler) handleGetCollections(rq *blip.Message) error { status, _ := base.ErrorAsHTTPStatus(err) if status == http.StatusNotFound { checkpoints[i] = Body{} - collectionContexts[i] = newBlipSyncCollectionContext(collection) + collectionContexts[i] = newBlipSyncCollectionContext(bh.loggingCtx, collection) } else { errMsg := fmt.Sprintf("Unable to fetch client checkpoint %q for collection %s: %s", key, scopeAndCollection, err) base.WarnfCtx(bh.loggingCtx, errMsg) @@ -90,7 +90,7 @@ func (bh *blipHandler) handleGetCollections(rq *blip.Message) error { } delete(value, BodyId) checkpoints[i] = value - collectionContexts[i] = newBlipSyncCollectionContext(collection) + collectionContexts[i] = newBlipSyncCollectionContext(bh.loggingCtx, collection) } bh.collections.set(collectionContexts) response := rq.Response() diff --git a/db/blip_messages.go b/db/blip_messages.go index 5f999a88b6..af70cb693a 100644 --- a/db/blip_messages.go +++ b/db/blip_messages.go @@ -22,7 +22,7 @@ import ( // blipMessageSender validates specific request types type blipMessageSender interface { - Send(s *blip.Sender) (err error) + Send(ctx context.Context, s *blip.Sender) (err error) } // SubChangesRequest is a strongly typed 'subChanges' request. @@ -41,8 +41,8 @@ type SubChangesRequest struct { var _ blipMessageSender = &SubChangesRequest{} -func (rq *SubChangesRequest) Send(s *blip.Sender) error { - r, err := rq.marshalBLIPRequest() +func (rq *SubChangesRequest) Send(ctx context.Context, s *blip.Sender) error { + r, err := rq.marshalBLIPRequest(ctx) if err != nil { return err } @@ -53,7 +53,7 @@ func (rq *SubChangesRequest) Send(s *blip.Sender) error { return nil } -func (rq *SubChangesRequest) marshalBLIPRequest() (*blip.Message, error) { +func (rq *SubChangesRequest) marshalBLIPRequest(ctx context.Context) (*blip.Message, error) { msg := blip.NewRequest() msg.SetProfile(MessageSubChanges) @@ -70,7 +70,7 @@ func (rq *SubChangesRequest) marshalBLIPRequest() (*blip.Message, error) { if err := msg.SetJSONBody(map[string]interface{}{ "docIDs": rq.DocIDs, }); err != nil { - base.ErrorfCtx(context.Background(), "error marshalling docIDs slice into subChanges request: %v", err) + base.ErrorfCtx(ctx, "error marshalling docIDs slice into subChanges request: %v", err) return nil, err } } @@ -90,7 +90,7 @@ type SetSGR2CheckpointRequest struct { var _ blipMessageSender = &SetSGR2CheckpointRequest{} -func (rq *SetSGR2CheckpointRequest) Send(s *blip.Sender) error { +func (rq *SetSGR2CheckpointRequest) Send(_ context.Context, s *blip.Sender) error { msg, err := rq.marshalBLIPRequest() if err != nil { return err @@ -161,7 +161,7 @@ type GetSGR2CheckpointRequest struct { var _ blipMessageSender = &GetSGR2CheckpointRequest{} -func (rq *GetSGR2CheckpointRequest) Send(s *blip.Sender) error { +func (rq *GetSGR2CheckpointRequest) Send(_ context.Context, s *blip.Sender) error { msg := rq.marshalBLIPRequest() if ok := s.Send(msg); !ok { diff --git a/db/blip_sync_context.go b/db/blip_sync_context.go index b90560a4e0..c6868c136f 100644 --- a/db/blip_sync_context.go +++ b/db/blip_sync_context.go @@ -169,11 +169,7 @@ func (bsc *BlipSyncContext) register(profile string, handlerFn func(*blipHandler }() startTime := time.Now() - handler := blipHandler{ - BlipSyncContext: bsc, - db: bsc.copyContextDatabase(), - serialNumber: bsc.incrementSerialNumber(), - } + handler := newBlipHandler(bsc.loggingCtx, bsc, bsc.copyContextDatabase(), bsc.incrementSerialNumber()) // Trace log the full message body and properties if base.LogTraceEnabled(base.KeySyncMsg) { @@ -181,7 +177,7 @@ func (bsc *BlipSyncContext) register(profile string, handlerFn func(*blipHandler base.TracefCtx(bsc.loggingCtx, base.KeySyncMsg, "Recv Req %s: Body: '%s' Properties: %v", rq, base.UD(rqBody), base.UD(rq.Properties)) } - if err := handlerFn(&handler, rq); err != nil { + if err := handlerFn(handler, rq); err != nil { status, msg := base.ErrorAsHTTPStatus(err) if response := rq.Response(); response != nil { response.SetError("HTTP", status, msg) diff --git a/db/change_cache.go b/db/change_cache.go index ad5fb9644e..264f96c40b 100644 --- a/db/change_cache.go +++ b/db/change_cache.go @@ -50,7 +50,7 @@ var EnableStarChannelLog = true // - Propagating DCP changes down to appropriate channel caches type changeCache struct { db *DatabaseContext - logCtx context.Context + logCtx context.Context // fix in sg-bucket to ProcessEvent logsDisabled bool // If true, ignore incoming tap changes nextSequence uint64 // Next consecutive sequence number to add. State variable for sequence buffering tracking. Should use getNextSequence() rather than accessing directly. initialSequence uint64 // DB's current sequence at startup time. Should use getInitialSequence() rather than accessing directly. @@ -79,7 +79,7 @@ type changeCacheStats struct { maxPending int } -func (c *changeCache) updateStats() { +func (c *changeCache) updateStats(ctx context.Context) { c.lock.Lock() defer c.lock.Unlock() @@ -89,7 +89,7 @@ func (c *changeCache) updateStats() { c.db.DbStats.Database().HighSeqFeed.SetIfMax(int64(c.internalStats.highSeqFeed)) c.db.DbStats.Cache().PendingSeqLen.Set(int64(c.internalStats.pendingSeqLen)) c.db.DbStats.CBLReplicationPull().MaxPending.SetIfMax(int64(c.internalStats.maxPending)) - c.db.DbStats.Cache().HighSeqStable.Set(int64(c._getMaxStableCached())) + c.db.DbStats.Cache().HighSeqStable.Set(int64(c._getMaxStableCached(ctx))) } @@ -162,9 +162,9 @@ func DefaultCacheOptions() CacheOptions { // After calling Init(), you must call .Start() to start useing the cache, otherwise it will be in a locked state // and callers will block on trying to obtain the lock. -func (c *changeCache) Init(logCtx context.Context, dbContext *DatabaseContext, channelCache ChannelCache, notifyChange func(context.Context, channels.Set), options *CacheOptions, metaKeys *base.MetadataKeys) error { +func (c *changeCache) Init(ctx context.Context, dbContext *DatabaseContext, channelCache ChannelCache, notifyChange func(context.Context, channels.Set), options *CacheOptions, metaKeys *base.MetadataKeys) error { c.db = dbContext - c.logCtx = logCtx + c.logCtx = ctx c.notifyChange = notifyChange c.receivedSeqs = make(map[uint64]struct{}) @@ -184,18 +184,18 @@ func (c *changeCache) Init(logCtx context.Context, dbContext *DatabaseContext, c c.channelCache = channelCache - base.InfofCtx(c.logCtx, base.KeyCache, "Initializing changes cache for %s with options %+v", base.UD(c.db.Name), c.options) + base.InfofCtx(ctx, base.KeyCache, "Initializing changes cache for %s with options %+v", base.UD(c.db.Name), c.options) heap.Init(&c.pendingLogs) // background tasks that perform housekeeping duties on the cache - bgt, err := NewBackgroundTask(c.logCtx, "InsertPendingEntries", c.InsertPendingEntries, c.options.CachePendingSeqMaxWait/2, c.terminator) + bgt, err := NewBackgroundTask(ctx, "InsertPendingEntries", c.InsertPendingEntries, c.options.CachePendingSeqMaxWait/2, c.terminator) if err != nil { return err } c.backgroundTasks = append(c.backgroundTasks, bgt) - bgt, err = NewBackgroundTask(c.logCtx, "CleanSkippedSequenceQueue", c.CleanSkippedSequenceQueue, c.options.CacheSkippedSeqMaxWait/2, c.terminator) + bgt, err = NewBackgroundTask(ctx, "CleanSkippedSequenceQueue", c.CleanSkippedSequenceQueue, c.options.CacheSkippedSeqMaxWait/2, c.terminator) if err != nil { return err } @@ -225,7 +225,7 @@ func (c *changeCache) Start(initialSequence uint64) error { } // Stops the cache. Clears its state and tells the housekeeping task to stop. -func (c *changeCache) Stop() { +func (c *changeCache) Stop(ctx context.Context) { if !c.started.IsTrue() { // changeCache never started - nothing to stop @@ -233,7 +233,7 @@ func (c *changeCache) Stop() { } if !c.stopped.CompareAndSwap(false, true) { - base.WarnfCtx(c.logCtx, "changeCache was already stopped") + base.WarnfCtx(ctx, "changeCache was already stopped") return } @@ -242,7 +242,7 @@ func (c *changeCache) Stop() { close(c.terminator) // Wait for changeCache background tasks to finish. - waitForBGTCompletion(context.TODO(), BGTCompletionMaxWait, c.backgroundTasks, c.db.Name) + waitForBGTCompletion(ctx, BGTCompletionMaxWait, c.backgroundTasks, c.db.Name) c.lock.Lock() c.logsDisabled = true @@ -250,7 +250,7 @@ func (c *changeCache) Stop() { } // Empty out all channel caches. -func (c *changeCache) Clear() error { +func (c *changeCache) Clear(ctx context.Context) error { c.lock.Lock() defer c.lock.Unlock() @@ -258,7 +258,7 @@ func (c *changeCache) Clear() error { // the point at which the change cache was initialized / re-initialized. // No need to touch c.nextSequence here, because we don't want to touch the sequence buffering state. var err error - c.initialSequence, err = c.db.LastSequence() + c.initialSequence, err = c.db.LastSequence(ctx) if err != nil { return err } @@ -289,7 +289,7 @@ func (c *changeCache) InsertPendingEntries(ctx context.Context) error { // Trigger _addPendingLogs to process any entries that have been pending too long: c.lock.Lock() - changedChannels := c._addPendingLogs() + changedChannels := c._addPendingLogs(ctx) if c.notifyChange != nil && len(changedChannels) > 0 { c.notifyChange(ctx, changedChannels) } @@ -326,7 +326,7 @@ func (c *changeCache) CleanSkippedSequenceQueue(ctx context.Context) error { // originating from multiple vbuckets). Only processEntry is locking - all other functionality needs to support // concurrent processing. func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { - + ctx := c.logCtx docID := string(event.Key) docJSON := event.Value changedChannelsCombined := channels.Set{} @@ -334,20 +334,20 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { // ** This method does not directly access any state of c, so it doesn't lock. // Is this a user/role doc for this database? if strings.HasPrefix(docID, c.metaKeys.UserKeyPrefix()) { - c.processPrincipalDoc(docID, docJSON, true, event.TimeReceived) + c.processPrincipalDoc(ctx, docID, docJSON, true, event.TimeReceived) return } else if strings.HasPrefix(docID, c.metaKeys.RoleKeyPrefix()) { - c.processPrincipalDoc(docID, docJSON, false, event.TimeReceived) + c.processPrincipalDoc(ctx, docID, docJSON, false, event.TimeReceived) return } // Is this an unused sequence notification? if strings.HasPrefix(docID, c.metaKeys.UnusedSeqPrefix()) { - c.processUnusedSequence(docID, event.TimeReceived) + c.processUnusedSequence(ctx, docID, event.TimeReceived) return } if strings.HasPrefix(docID, c.metaKeys.UnusedSeqRangePrefix()) { - c.processUnusedSequenceRange(docID) + c.processUnusedSequenceRange(ctx, docID) return } @@ -360,13 +360,13 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { // If this is a delete and there are no xattrs (no existing SG revision), we can ignore if event.Opcode == sgbucket.FeedOpDeletion && len(docJSON) == 0 { - base.DebugfCtx(c.logCtx, base.KeyImport, "Ignoring delete mutation for %s - no existing Sync Gateway metadata.", base.UD(docID)) + base.DebugfCtx(ctx, base.KeyImport, "Ignoring delete mutation for %s - no existing Sync Gateway metadata.", base.UD(docID)) return } collection, exists := c.db.CollectionByID[event.CollectionID] if !exists { - base.WarnfCtx(c.logCtx, "DocChanged: could not find collection with kv ID: %d", event.CollectionID) + base.WarnfCtx(ctx, "DocChanged: could not find collection with kv ID: %d", event.CollectionID) return } // If this is a binary document (and not one of the above types), we can ignore. Currently only performing this check when xattrs @@ -380,10 +380,10 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { if err != nil { // Avoid log noise related to failed unmarshaling of binary documents. if event.DataType != base.MemcachedDataTypeRaw { - base.DebugfCtx(c.logCtx, base.KeyCache, "Unable to unmarshal sync metadata for feed document %q. Will not be included in channel cache. Error: %v", base.UD(docID), err) + base.DebugfCtx(ctx, base.KeyCache, "Unable to unmarshal sync metadata for feed document %q. Will not be included in channel cache. Error: %v", base.UD(docID), err) } if err == base.ErrEmptyMetadata { - base.WarnfCtx(c.logCtx, "Unexpected empty metadata when processing feed event. docid: %s opcode: %v datatype:%v", base.UD(event.Key), event.Opcode, event.DataType) + base.WarnfCtx(ctx, "Unexpected empty metadata when processing feed event. docid: %s opcode: %v datatype:%v", base.UD(event.Key), event.Opcode, event.DataType) } return } @@ -402,12 +402,12 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { // If not using xattrs and no sync metadata found, check whether we're mid-upgrade and attempting to read a doc w/ metadata stored in xattr // before ignoring the mutation. if !collection.UseXattrs() && !syncData.HasValidSyncData() { - migratedDoc, _ := collection.checkForUpgrade(docID, DocUnmarshalNoHistory) + migratedDoc, _ := collection.checkForUpgrade(ctx, docID, DocUnmarshalNoHistory) if migratedDoc != nil && migratedDoc.Cas == event.Cas { - base.InfofCtx(c.logCtx, base.KeyCache, "Found mobile xattr on doc %q without %s property - caching, assuming upgrade in progress.", base.UD(docID), base.SyncPropertyName) + base.InfofCtx(ctx, base.KeyCache, "Found mobile xattr on doc %q without %s property - caching, assuming upgrade in progress.", base.UD(docID), base.SyncPropertyName) syncData = &migratedDoc.SyncData } else { - base.InfofCtx(c.logCtx, base.KeyCache, "changeCache: Doc %q does not have valid sync data.", base.UD(docID)) + base.InfofCtx(ctx, base.KeyCache, "changeCache: Doc %q does not have valid sync data.", base.UD(docID)) collection.dbStats().Cache().NonMobileIgnoredCount.Add(1) return } @@ -435,13 +435,13 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { // If the doc update wasted any sequences due to conflicts, add empty entries for them: for _, seq := range syncData.UnusedSequences { - base.InfofCtx(c.logCtx, base.KeyCache, "Received unused #%d in unused_sequences property for (%q / %q)", seq, base.UD(docID), syncData.CurrentRev) + base.InfofCtx(ctx, base.KeyCache, "Received unused #%d in unused_sequences property for (%q / %q)", seq, base.UD(docID), syncData.CurrentRev) change := &LogEntry{ Sequence: seq, TimeReceived: event.TimeReceived, CollectionID: event.CollectionID, } - changedChannels := c.processEntry(change) + changedChannels := c.processEntry(ctx, change) changedChannelsCombined = changedChannelsCombined.Update(changedChannels) } @@ -459,7 +459,7 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { for _, seq := range syncData.RecentSequences { if seq >= c.getNextSequence() && seq < currentSequence { - base.InfofCtx(c.logCtx, base.KeyCache, "Received deduplicated #%d in recent_sequences property for (%q / %q)", seq, base.UD(docID), syncData.CurrentRev) + base.InfofCtx(ctx, base.KeyCache, "Received deduplicated #%d in recent_sequences property for (%q / %q)", seq, base.UD(docID), syncData.CurrentRev) change := &LogEntry{ Sequence: seq, TimeReceived: event.TimeReceived, @@ -474,7 +474,7 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { change.Channels = channelRemovals } - changedChannels := c.processEntry(change) + changedChannels := c.processEntry(ctx, change) changedChannelsCombined = changedChannelsCombined.Update(changedChannels) } } @@ -499,17 +499,17 @@ func (c *changeCache) DocChanged(event sgbucket.FeedEvent) { // If latency is larger than 1 minute or is negative there is likely an issue and this should be clear to the user if millisecondLatency >= 60*1000 { - base.InfofCtx(c.logCtx, base.KeyDCP, "Received #%d after %3dms (%q / %q)", change.Sequence, millisecondLatency, base.UD(change.DocID), change.RevID) + base.InfofCtx(ctx, base.KeyDCP, "Received #%d after %3dms (%q / %q)", change.Sequence, millisecondLatency, base.UD(change.DocID), change.RevID) } else { - base.DebugfCtx(c.logCtx, base.KeyDCP, "Received #%d after %3dms (%q / %q)", change.Sequence, millisecondLatency, base.UD(change.DocID), change.RevID) + base.DebugfCtx(ctx, base.KeyDCP, "Received #%d after %3dms (%q / %q)", change.Sequence, millisecondLatency, base.UD(change.DocID), change.RevID) } - changedChannels := c.processEntry(change) + changedChannels := c.processEntry(ctx, change) changedChannelsCombined = changedChannelsCombined.Update(changedChannels) // Notify change listeners for all of the changed channels if c.notifyChange != nil && len(changedChannelsCombined) > 0 { - c.notifyChange(c.logCtx, changedChannelsCombined) + c.notifyChange(ctx, changedChannelsCombined) } } @@ -520,8 +520,8 @@ type cachePrincipal struct { Sequence uint64 `json:"sequence"` } -func (c *changeCache) Remove(collectionID uint32, docIDs []string, startTime time.Time) (count int) { - return c.channelCache.Remove(collectionID, docIDs, startTime) +func (c *changeCache) Remove(ctx context.Context, collectionID uint32, docIDs []string, startTime time.Time) (count int) { + return c.channelCache.Remove(ctx, collectionID, docIDs, startTime) } // Principals unmarshalled during caching don't need to instantiate a real principal - we're just using name and seq from the document @@ -532,27 +532,27 @@ func (c *changeCache) unmarshalCachePrincipal(docJSON []byte) (cachePrincipal, e } // Process unused sequence notification. Extracts sequence from docID and sends to cache for buffering -func (c *changeCache) processUnusedSequence(docID string, timeReceived time.Time) { +func (c *changeCache) processUnusedSequence(ctx context.Context, docID string, timeReceived time.Time) { sequenceStr := strings.TrimPrefix(docID, c.metaKeys.UnusedSeqPrefix()) sequence, err := strconv.ParseUint(sequenceStr, 10, 64) if err != nil { - base.WarnfCtx(c.logCtx, "Unable to identify sequence number for unused sequence notification with key: %s, error: %v", base.UD(docID), err) + base.WarnfCtx(ctx, "Unable to identify sequence number for unused sequence notification with key: %s, error: %v", base.UD(docID), err) return } - c.releaseUnusedSequence(sequence, timeReceived) + c.releaseUnusedSequence(ctx, sequence, timeReceived) } -func (c *changeCache) releaseUnusedSequence(sequence uint64, timeReceived time.Time) { +func (c *changeCache) releaseUnusedSequence(ctx context.Context, sequence uint64, timeReceived time.Time) { change := &LogEntry{ Sequence: sequence, TimeReceived: timeReceived, } - base.InfofCtx(c.logCtx, base.KeyCache, "Received #%d (unused sequence)", sequence) + base.InfofCtx(ctx, base.KeyCache, "Received #%d (unused sequence)", sequence) // Since processEntry may unblock pending sequences, if there were any changed channels we need // to notify any change listeners that are working changes feeds for these channels - changedChannels := c.processEntry(change) + changedChannels := c.processEntry(ctx, change) unusedSeq := channels.NewID(unusedSeqKey, unusedSeqCollectionID) if changedChannels == nil { changedChannels = channels.SetOfNoValidate(unusedSeq) @@ -561,12 +561,12 @@ func (c *changeCache) releaseUnusedSequence(sequence uint64, timeReceived time.T } c.channelCache.AddSkippedSequence(change) if c.notifyChange != nil && len(changedChannels) > 0 { - c.notifyChange(c.logCtx, changedChannels) + c.notifyChange(ctx, changedChannels) } } // Process unused sequence notification. Extracts sequence from docID and sends to cache for buffering -func (c *changeCache) processUnusedSequenceRange(docID string) { +func (c *changeCache) processUnusedSequenceRange(ctx context.Context, docID string) { // _sync:unusedSequences:fromSeq:toSeq sequencesStr := strings.TrimPrefix(docID, c.metaKeys.UnusedSeqRangePrefix()) sequences := strings.Split(sequencesStr, ":") @@ -576,29 +576,29 @@ func (c *changeCache) processUnusedSequenceRange(docID string) { fromSequence, err := strconv.ParseUint(sequences[0], 10, 64) if err != nil { - base.WarnfCtx(c.logCtx, "Unable to identify from sequence number for unused sequences notification with key: %s, error:", base.UD(docID), err) + base.WarnfCtx(ctx, "Unable to identify from sequence number for unused sequences notification with key: %s, error:", base.UD(docID), err) return } toSequence, err := strconv.ParseUint(sequences[1], 10, 64) if err != nil { - base.WarnfCtx(c.logCtx, "Unable to identify to sequence number for unused sequence notification with key: %s, error:", base.UD(docID), err) + base.WarnfCtx(ctx, "Unable to identify to sequence number for unused sequence notification with key: %s, error:", base.UD(docID), err) return } // TODO: There should be a more efficient way to do this for seq := fromSequence; seq <= toSequence; seq++ { - c.releaseUnusedSequence(seq, time.Now()) + c.releaseUnusedSequence(ctx, seq, time.Now()) } } -func (c *changeCache) processPrincipalDoc(docID string, docJSON []byte, isUser bool, timeReceived time.Time) { +func (c *changeCache) processPrincipalDoc(ctx context.Context, docID string, docJSON []byte, isUser bool, timeReceived time.Time) { // Currently the cache isn't really doing much with user docs; mostly it needs to know about // them because they have sequence numbers, so without them the sequence of sequences would // have gaps in it, causing later sequences to get stuck in the queue. princ, err := c.unmarshalCachePrincipal(docJSON) if err != nil { - base.WarnfCtx(c.logCtx, "changeCache: Error unmarshaling doc %q: %v", base.UD(docID), err) + base.WarnfCtx(ctx, "changeCache: Error unmarshaling doc %q: %v", base.UD(docID), err) return } sequence := princ.Sequence @@ -619,16 +619,16 @@ func (c *changeCache) processPrincipalDoc(docID string, docJSON []byte, isUser b change.DocID = "_role/" + princ.Name } - base.InfofCtx(c.logCtx, base.KeyDCP, "Received #%d (%q)", change.Sequence, base.UD(change.DocID)) + base.InfofCtx(ctx, base.KeyDCP, "Received #%d (%q)", change.Sequence, base.UD(change.DocID)) - changedChannels := c.processEntry(change) + changedChannels := c.processEntry(ctx, change) if c.notifyChange != nil && len(changedChannels) > 0 { - c.notifyChange(c.logCtx, changedChannels) + c.notifyChange(ctx, changedChannels) } } // Handles a newly-arrived LogEntry. -func (c *changeCache) processEntry(change *LogEntry) channels.Set { +func (c *changeCache) processEntry(ctx context.Context, change *LogEntry) channels.Set { c.lock.Lock() defer c.lock.Unlock() if c.logsDisabled { @@ -646,13 +646,13 @@ func (c *changeCache) processEntry(change *LogEntry) channels.Set { // We can cancel processing early in these scenarios. // Check if this is a duplicate of an already processed sequence if sequence < c.nextSequence && !c.WasSkipped(sequence) { - base.DebugfCtx(c.logCtx, base.KeyCache, " Ignoring duplicate of #%d", sequence) + base.DebugfCtx(ctx, base.KeyCache, " Ignoring duplicate of #%d", sequence) return nil } // Check if this is a duplicate of a pending sequence if _, found := c.receivedSeqs[sequence]; found { - base.DebugfCtx(c.logCtx, base.KeyCache, " Ignoring duplicate of #%d", sequence) + base.DebugfCtx(ctx, base.KeyCache, " Ignoring duplicate of #%d", sequence) return nil } c.receivedSeqs[sequence] = struct{}{} @@ -660,16 +660,16 @@ func (c *changeCache) processEntry(change *LogEntry) channels.Set { var changedChannels channels.Set if sequence == c.nextSequence || c.nextSequence == 0 { // This is the expected next sequence so we can add it now: - changedChannels = channels.SetFromArrayNoValidate(c._addToCache(change)) + changedChannels = channels.SetFromArrayNoValidate(c._addToCache(ctx, change)) // Also add any pending sequences that are now contiguous: - changedChannels = changedChannels.Update(c._addPendingLogs()) + changedChannels = changedChannels.Update(c._addPendingLogs(ctx)) } else if sequence > c.nextSequence { // There's a missing sequence (or several), so put this one on ice until it arrives: heap.Push(&c.pendingLogs, change) numPending := len(c.pendingLogs) c.internalStats.pendingSeqLen = numPending if base.LogDebugEnabled(base.KeyCache) { - base.DebugfCtx(c.logCtx, base.KeyCache, " Deferring #%d (%d now waiting for #%d...#%d) doc %q / %q", + base.DebugfCtx(ctx, base.KeyCache, " Deferring #%d (%d now waiting for #%d...#%d) doc %q / %q", sequence, numPending, c.nextSequence, c.pendingLogs[0].Sequence-1, base.UD(change.DocID), change.RevID) } // Update max pending high watermark stat @@ -679,25 +679,25 @@ func (c *changeCache) processEntry(change *LogEntry) channels.Set { if numPending > c.options.CachePendingSeqMaxNum { // Too many pending; add the oldest one: - changedChannels = c._addPendingLogs() + changedChannels = c._addPendingLogs(ctx) } } else if sequence > c.initialSequence { // Out-of-order sequence received! // Remove from skipped sequence queue if !c.WasSkipped(sequence) { // Error removing from skipped sequences - base.InfofCtx(c.logCtx, base.KeyCache, " Received unexpected out-of-order change - not in skippedSeqs (seq %d, expecting %d) doc %q / %q", sequence, c.nextSequence, base.UD(change.DocID), change.RevID) + base.InfofCtx(ctx, base.KeyCache, " Received unexpected out-of-order change - not in skippedSeqs (seq %d, expecting %d) doc %q / %q", sequence, c.nextSequence, base.UD(change.DocID), change.RevID) } else { - base.InfofCtx(c.logCtx, base.KeyCache, " Received previously skipped out-of-order change (seq %d, expecting %d) doc %q / %q ", sequence, c.nextSequence, base.UD(change.DocID), change.RevID) + base.InfofCtx(ctx, base.KeyCache, " Received previously skipped out-of-order change (seq %d, expecting %d) doc %q / %q ", sequence, c.nextSequence, base.UD(change.DocID), change.RevID) change.Skipped = true } - changedChannels = changedChannels.UpdateWithSlice(c._addToCache(change)) + changedChannels = changedChannels.UpdateWithSlice(c._addToCache(ctx, change)) // Add to cache before removing from skipped, to ensure lowSequence doesn't get incremented until results are available // in cache err := c.RemoveSkipped(sequence) if err != nil { - base.DebugfCtx(c.logCtx, base.KeyCache, "Error removing skipped sequence: #%d from cache: %v", sequence, err) + base.DebugfCtx(ctx, base.KeyCache, "Error removing skipped sequence: #%d from cache: %v", sequence, err) } } return changedChannels @@ -705,7 +705,7 @@ func (c *changeCache) processEntry(change *LogEntry) channels.Set { // Adds an entry to the appropriate channels' caches, returning the affected channels. lateSequence // flag indicates whether it was a change arriving out of sequence -func (c *changeCache) _addToCache(change *LogEntry) []channels.ID { +func (c *changeCache) _addToCache(ctx context.Context, change *LogEntry) []channels.ID { if change.Sequence >= c.nextSequence { c.nextSequence = change.Sequence + 1 @@ -724,9 +724,9 @@ func (c *changeCache) _addToCache(change *LogEntry) []channels.ID { // updatedChannels tracks the set of channels that should be notified of the change. This includes // the change's active channels, as well as any channel removals for the active revision. - updatedChannels := c.channelCache.AddToCache(change) + updatedChannels := c.channelCache.AddToCache(ctx, change) if base.LogDebugEnabled(base.KeyDCP) { - base.DebugfCtx(c.logCtx, base.KeyDCP, " #%d ==> channels %v", change.Sequence, base.UD(updatedChannels)) + base.DebugfCtx(ctx, base.KeyDCP, " #%d ==> channels %v", change.Sequence, base.UD(updatedChannels)) } if !change.TimeReceived.IsZero() { @@ -740,7 +740,7 @@ func (c *changeCache) _addToCache(change *LogEntry) []channels.ID { // Add the first change(s) from pendingLogs if they're the next sequence. If not, and we've been // waiting too long for nextSequence, move nextSequence to skipped queue. // Returns the channels that changed. -func (c *changeCache) _addPendingLogs() channels.Set { +func (c *changeCache) _addPendingLogs(ctx context.Context) channels.Set { var changedChannels channels.Set for len(c.pendingLogs) > 0 { @@ -748,10 +748,10 @@ func (c *changeCache) _addPendingLogs() channels.Set { isNext := change.Sequence == c.nextSequence if isNext { heap.Pop(&c.pendingLogs) - changedChannels = changedChannels.UpdateWithSlice(c._addToCache(change)) + changedChannels = changedChannels.UpdateWithSlice(c._addToCache(ctx, change)) } else if len(c.pendingLogs) > c.options.CachePendingSeqMaxNum || time.Since(c.pendingLogs[0].TimeReceived) >= c.options.CachePendingSeqMaxWait { c.db.DbStats.Cache().NumSkippedSeqs.Add(1) - c.PushSkipped(c.nextSequence) + c.PushSkipped(ctx, c.nextSequence) c.nextSequence++ } else { break @@ -790,10 +790,10 @@ func (c *changeCache) LastSequence() uint64 { return lastSequence } -func (c *changeCache) getOldestSkippedSequence() uint64 { +func (c *changeCache) getOldestSkippedSequence(ctx context.Context) uint64 { oldestSkippedSeq := c.skippedSeqs.getOldest() if oldestSkippedSeq > 0 { - base.DebugfCtx(c.logCtx, base.KeyChanges, "Get oldest skipped, returning: %d", oldestSkippedSeq) + base.DebugfCtx(ctx, base.KeyChanges, "Get oldest skipped, returning: %d", oldestSkippedSeq) } return oldestSkippedSeq } @@ -857,10 +857,10 @@ func (c *changeCache) WasSkipped(x uint64) bool { return c.skippedSeqs.Contains(x) } -func (c *changeCache) PushSkipped(sequence uint64) { +func (c *changeCache) PushSkipped(ctx context.Context, sequence uint64) { err := c.skippedSeqs.Push(&SkippedSequence{seq: sequence, timeAdded: time.Now()}) if err != nil { - base.InfofCtx(c.logCtx, base.KeyCache, "Error pushing skipped sequence: %d, %v", sequence, err) + base.InfofCtx(ctx, base.KeyCache, "Error pushing skipped sequence: %d, %v", sequence, err) return } c.db.DbStats.Cache().SkippedSeqLen.Set(int64(c.skippedSeqs.skippedList.Len())) @@ -885,7 +885,7 @@ func (c *changeCache) waitForSequence(ctx context.Context, sequence uint64, maxW ctx, cancel := context.WithDeadline(ctx, startTime.Add(maxWaitTime)) sleeper := base.SleeperFuncCtx(base.CreateMaxDoublingSleeperFunc(math.MaxInt64, 1, 100), ctx) - err, _ := base.RetryLoop(fmt.Sprintf("waitForSequence(%d)", sequence), worker, sleeper) + err, _ := base.RetryLoop(ctx, fmt.Sprintf("waitForSequence(%d)", sequence), worker, sleeper) cancel() return err } @@ -908,19 +908,19 @@ func (c *changeCache) waitForSequenceNotSkipped(ctx context.Context, sequence ui ctx, cancel := context.WithDeadline(ctx, startTime.Add(maxWaitTime)) sleeper := base.SleeperFuncCtx(base.CreateMaxDoublingSleeperFunc(math.MaxInt64, 1, 100), ctx) - err, _ := base.RetryLoop(fmt.Sprintf("waitForSequenceNotSkipped(%d)", sequence), worker, sleeper) + err, _ := base.RetryLoop(ctx, fmt.Sprintf("waitForSequenceNotSkipped(%d)", sequence), worker, sleeper) cancel() return err } -func (c *changeCache) getMaxStableCached() uint64 { +func (c *changeCache) getMaxStableCached(ctx context.Context) uint64 { c.lock.RLock() defer c.lock.RUnlock() - return c._getMaxStableCached() + return c._getMaxStableCached(ctx) } -func (c *changeCache) _getMaxStableCached() uint64 { - oldestSkipped := c.getOldestSkippedSequence() +func (c *changeCache) _getMaxStableCached(ctx context.Context) uint64 { + oldestSkipped := c.getOldestSkippedSequence(ctx) if oldestSkipped > 0 { return oldestSkipped - 1 } diff --git a/db/change_cache_test.go b/db/change_cache_test.go index fc08401c14..dbd0f24ee0 100644 --- a/db/change_cache_test.go +++ b/db/change_cache_test.go @@ -272,7 +272,7 @@ func TestLateSequenceErrorRecovery(t *testing.T) { // Start continuous changes feed var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx defer changesCtxCancel() options.Continuous = true @@ -321,7 +321,7 @@ func TestLateSequenceErrorRecovery(t *testing.T) { // Modify the cache's late logs to remove the changes feed's lateFeedHandler sequence from the // cache's lateLogs. This will trigger an error on the next feed iteration, which should trigger // rollback to resend all changes since low sequence (1) - c, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("ABC", collectionID)) + c, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("ABC", collectionID)) require.NoError(t, err) abcCache := c.(*singleChannelCacheImpl) abcCache.lateLogs[0].logEntry.Sequence = 1 @@ -580,7 +580,7 @@ func TestChannelCacheBufferingWithUserDoc(t *testing.T) { successChan := make(chan bool) go func() { - waiter.Wait() + waiter.Wait(ctx) close(successChan) }() @@ -624,7 +624,7 @@ func TestChannelCacheBackfill(t *testing.T) { // Test that retrieval isn't blocked by skipped sequences require.NoError(t, db.changeCache.waitForSequence(ctx, 6, base.DefaultWaitForSequence)) collection.user, _ = authenticator.GetUser("naomi") - changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq()) + changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err, "Couldn't GetChanges") assert.Equal(t, 4, len(changes)) @@ -644,25 +644,25 @@ func TestChannelCacheBackfill(t *testing.T) { require.NoError(t, db.changeCache.waitForSequence(ctx, 7, base.DefaultWaitForSequence)) // verify insert at start (PBS) - pbsCache, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("PBS", collectionID)) + pbsCache, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("PBS", collectionID)) require.NoError(t, err) assert.True(t, verifyCacheSequences(pbsCache, []uint64{3, 5, 6})) // verify insert at middle (ABC) - abcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("ABC", collectionID)) + abcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("ABC", collectionID)) require.NoError(t, err) assert.True(t, verifyCacheSequences(abcCache, []uint64{1, 2, 3, 5, 6})) // verify insert at end (NBC) - nbcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("NBC", collectionID)) + nbcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("NBC", collectionID)) require.NoError(t, err) assert.True(t, verifyCacheSequences(nbcCache, []uint64{1, 3})) // verify insert to empty cache (TBS) - tbsCache, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("TBS", collectionID)) + tbsCache, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("TBS", collectionID)) require.NoError(t, err) assert.True(t, verifyCacheSequences(tbsCache, []uint64{3})) // verify changes has three entries (needs to resend all since previous LowSeq, which // will be the late arriver (3) along with 5, 6) - changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(lastSeq)) + changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq)) require.NoError(t, err) assert.Equal(t, 3, len(changes)) assert.Equal(t, &ChangeEntry{ @@ -702,7 +702,7 @@ func TestContinuousChangesBackfill(t *testing.T) { // Start changes feed var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx options.Continuous = true options.Wait = true @@ -807,7 +807,7 @@ func TestLowSequenceHandling(t *testing.T) { var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx defer changesCtxCancel() options.Continuous = true @@ -875,7 +875,7 @@ func TestLowSequenceHandlingAcrossChannels(t *testing.T) { var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx options.Continuous = true options.Wait = true @@ -931,7 +931,7 @@ func TestLowSequenceHandlingWithAccessGrant(t *testing.T) { var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx options.Continuous = true options.Wait = true @@ -1056,7 +1056,7 @@ func TestChannelQueryCancellation(t *testing.T) { defer changesWg.Done() var options ChangesOptions options.Since = SequenceID{Seq: 0} - options.ChangesCtx = context.Background() + options.ChangesCtx = base.TestCtx(t) options.Continuous = false options.Wait = false options.Limit = 2 // Avoid prepending results in cache, as we don't want second changes to serve results from cache @@ -1075,7 +1075,7 @@ func TestChannelQueryCancellation(t *testing.T) { initialPendingQueries := db.DbStats.Cache().ChannelCachePendingQueries.Value() // Start a second goroutine that should block waiting for the view lock - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) changesWg.Add(1) go func() { defer changesWg.Done() @@ -1147,7 +1147,7 @@ func TestLowSequenceHandlingNoDuplicates(t *testing.T) { var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx defer changesCtxCancel() options.Continuous = true @@ -1239,7 +1239,7 @@ func TestChannelRace(t *testing.T) { var options ChangesOptions options.Since = SequenceID{Seq: 0} - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx options.Continuous = true options.Wait = true @@ -1369,14 +1369,14 @@ func TestChannelCacheSize(t *testing.T) { // Validate that retrieval returns expected sequences require.NoError(t, db.changeCache.waitForSequence(ctx, 750, base.DefaultWaitForSequence)) collection.user, _ = authenticator.GetUser("naomi") - changes, err := collection.GetChanges(ctx, base.SetOf("ABC"), getChangesOptionsWithZeroSeq()) + changes, err := collection.GetChanges(ctx, base.SetOf("ABC"), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err, "Couldn't GetChanges") assert.Equal(t, 750, len(changes)) // Validate that cache stores the expected number of values collectionID := collection.GetCollectionID() - abcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("ABC", collectionID)) + abcCache, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("ABC", collectionID)) require.NoError(t, err) assert.Equal(t, 600, len(abcCache.(*singleChannelCacheImpl).logs)) @@ -1687,7 +1687,7 @@ func TestInitializeEmptyCache(t *testing.T) { } // Issue getChanges for empty channel - changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly()) + changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t)) assert.NoError(t, err, "Couldn't GetChanges") changesCount := len(changes) assert.Equal(t, 0, changesCount) @@ -1705,7 +1705,7 @@ func TestInitializeEmptyCache(t *testing.T) { cacheWaiter.Add(docCount) cacheWaiter.Wait() - changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly()) + changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t)) assert.NoError(t, err, "Couldn't GetChanges") changesCount = len(changes) assert.Equal(t, 10, changesCount) @@ -1752,7 +1752,7 @@ func TestInitializeCacheUnderLoad(t *testing.T) { // Wait for writes to be in progress, then getChanges for channel zero writesInProgress.Wait() - changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly()) + changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithCtxOnly(t)) require.NoError(t, err, "Couldn't GetChanges") firstChangesCount := len(changes) var lastSeq SequenceID @@ -1763,7 +1763,7 @@ func TestInitializeCacheUnderLoad(t *testing.T) { // Wait for all writes to be cached, then getChanges again cacheWaiter.Wait() - changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithSeq(lastSeq)) + changes, err = collection.GetChanges(ctx, channels.BaseSetOf(t, "zero"), getChangesOptionsWithSeq(t, lastSeq)) require.NoError(t, err, "Couldn't GetChanges") secondChangesCount := len(changes) assert.Equal(t, docCount, firstChangesCount+secondChangesCount) @@ -1993,13 +1993,13 @@ func BenchmarkProcessEntry(b *testing.B) { log.Printf("Start error for changeCache: %v", err) b.Fail() } - defer changeCache.Stop() + defer changeCache.Stop(ctx) require.NoError(b, err) if bm.warmCacheCount > 0 { for i := 0; i < bm.warmCacheCount; i++ { channel := channels.NewID(fmt.Sprintf("channel_%d", i), collectionID) - _, err := changeCache.GetChanges(ctx, channel, getChangesOptionsWithZeroSeq()) + _, err := changeCache.GetChanges(ctx, channel, getChangesOptionsWithZeroSeq(b)) if err != nil { log.Printf("GetChanges failed for changeCache: %v", err) b.Fail() @@ -2012,7 +2012,7 @@ func BenchmarkProcessEntry(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { entry := bm.feed.Next() - _ = changeCache.processEntry(entry) + _ = changeCache.processEntry(ctx, entry) } }) } @@ -2227,13 +2227,13 @@ func BenchmarkDocChanged(b *testing.B) { log.Printf("Start error for changeCache: %v", err) b.Fail() } - defer changeCache.Stop() + defer changeCache.Stop(ctx) require.NoError(b, err) if bm.warmCacheCount > 0 { for i := 0; i < bm.warmCacheCount; i++ { channel := channels.NewID(fmt.Sprintf("channel_%d", i), collectionID) - _, err := changeCache.GetChanges(ctx, channel, getChangesOptionsWithZeroSeq()) + _, err := changeCache.GetChanges(ctx, channel, getChangesOptionsWithZeroSeq(b)) if err != nil { log.Printf("GetChanges failed for changeCache: %v", err) b.Fail() diff --git a/db/change_listener.go b/db/change_listener.go index 96b3720715..254377aaec 100644 --- a/db/change_listener.go +++ b/db/change_listener.go @@ -136,7 +136,7 @@ func (listener *changeListener) StartMutationFeed(ctx context.Context, bucket ba } }() defer base.FatalPanicHandler() - defer listener.notifyStopping() + defer listener.notifyStopping(ctx) for event := range listener.tapFeed.Events() { event.TimeReceived = time.Now() listener.ProcessFeedEvent(event) @@ -258,7 +258,7 @@ func (listener *changeListener) notifyKey(ctx context.Context, key string) { } // Changes the counter, notifying waiting clients. -func (listener *changeListener) NotifyCheckForTermination(keys base.Set) { +func (listener *changeListener) NotifyCheckForTermination(ctx context.Context, keys base.Set) { if len(keys) == 0 { return } @@ -272,25 +272,25 @@ func (listener *changeListener) NotifyCheckForTermination(keys base.Set) { listener.terminateCheckCounter = 0 } - base.DebugfCtx(context.TODO(), base.KeyChanges, "Notifying to check for _changes feed termination") + base.DebugfCtx(ctx, base.KeyChanges, "Notifying to check for _changes feed termination") listener.tapNotifier.Broadcast() listener.tapNotifier.L.Unlock() } -func (listener *changeListener) notifyStopping() { +func (listener *changeListener) notifyStopping(ctx context.Context) { listener.tapNotifier.L.Lock() listener.counter = 0 listener.keyCounts = map[string]uint64{} - base.DebugfCtx(context.TODO(), base.KeyChanges, "Notifying that changeListener is stopping") + base.DebugfCtx(ctx, base.KeyChanges, "Notifying that changeListener is stopping") listener.tapNotifier.Broadcast() listener.tapNotifier.L.Unlock() } // Waits until either the counter, or terminateCheckCounter exceeds the given value. Returns the new counters. -func (listener *changeListener) Wait(keys []string, counter uint64, terminateCheckCounter uint64) (uint64, uint64) { +func (listener *changeListener) Wait(ctx context.Context, keys []string, counter uint64, terminateCheckCounter uint64) (uint64, uint64) { listener.tapNotifier.L.Lock() defer listener.tapNotifier.L.Unlock() - base.DebugfCtx(context.TODO(), base.KeyChanges, "No new changes to send to change listener. Waiting for %q's count to pass %d", + base.DebugfCtx(ctx, base.KeyChanges, "No new changes to send to change listener. Waiting for %q's count to pass %d", base.MD(listener.bucketName), counter) for { @@ -379,11 +379,11 @@ func (listener *changeListener) NewWaiterWithChannels(chans channels.Set, user a } // Waits for the changeListener's counter to change from the last time Wait() was called. -func (waiter *ChangeWaiter) Wait() uint32 { +func (waiter *ChangeWaiter) Wait(ctx context.Context) uint32 { lastTerminateCheckCounter := waiter.lastTerminateCheckCounter lastCounter := waiter.lastCounter - waiter.lastCounter, waiter.lastTerminateCheckCounter = waiter.listener.Wait(waiter.keys, waiter.lastCounter, waiter.lastTerminateCheckCounter) + waiter.lastCounter, waiter.lastTerminateCheckCounter = waiter.listener.Wait(ctx, waiter.keys, waiter.lastCounter, waiter.lastTerminateCheckCounter) if waiter.userKeys != nil { waiter.lastUserCount = waiter.listener.CurrentCount(waiter.userKeys) } diff --git a/db/changes.go b/db/changes.go index 952ace30bf..bc6bbe5179 100644 --- a/db/changes.go +++ b/db/changes.go @@ -141,7 +141,7 @@ func (db *DatabaseCollectionWithUser) AddDocToChangeEntryUsingRevCache(ctx conte if err != nil { return err } - entry.Doc, err = rev.As1xBytes(db, nil, nil, false) + entry.Doc, err = rev.As1xBytes(ctx, db, nil, nil, false) return err } @@ -678,14 +678,14 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex // Mark channel set as active, schedule defer col.activeChannels().IncrChannels(collectionID, channelsSince) - defer col.activeChannels().DecrChannels(collectionID, channelsSince) + defer col.activeChannels().DecrChannels(ctx, collectionID, channelsSince) // For a continuous feed, initialise the lateSequenceFeeds that track late-arriving sequences // to the channel caches. if options.Continuous || options.RequestPlusSeq > currentCachedSequence { useLateSequenceFeeds = true lateSequenceFeeds = make(map[channels.ID]*lateSequenceFeed) - defer col.closeLateFeeds(lateSequenceFeeds) + defer col.closeLateFeeds(ctx, lateSequenceFeeds) } // Store incoming low sequence, for potential use by longpoll iterations @@ -705,7 +705,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex // lowSequence is used to send composite keys to clients, so that they can obtain any currently // skipped sequences in a future iteration or request. - oldestSkipped := col.changeCache().getOldestSkippedSequence() + oldestSkipped := col.changeCache().getOldestSkippedSequence(ctx) if oldestSkipped > 0 { lowSequence = oldestSkipped - 1 base.InfofCtx(ctx, base.KeyChanges, "%d is the oldest skipped sequence, using stable sequence number of %d for this feed %s", oldestSkipped, lowSequence, base.UD(to)) @@ -736,7 +736,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex chanID := channels.NewID(chanName, collectionID) // Obtain a SingleChannelCache instance to use for both normal and late feeds. Required to ensure consistency // if cache is evicted during processing - singleChannelCache, err := col.changeCache().getChannelCache().getSingleChannelCache(chanID) + singleChannelCache, err := col.changeCache().getChannelCache().getSingleChannelCache(ctx, chanID) if err != nil { base.WarnfCtx(ctx, "Unable to obtain channel cache for %s, terminating feed", base.UD(chanName)) change := makeErrorEntry("Channel cache unavailable, terminating feed") @@ -1006,7 +1006,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex } col.dbStats().CBLReplicationPull().NumPullReplCaughtUp.Add(1) - waitResponse := changeWaiter.Wait() + waitResponse := changeWaiter.Wait(ctx) col.dbStats().CBLReplicationPull().NumPullReplCaughtUp.Add(-1) if waitResponse == WaiterClosed { @@ -1044,7 +1044,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex changedChannels = newChannelsSince.CompareKeys(channelsSince) if len(changedChannels) > 0 { - col.activeChannels().UpdateChanged(collectionID, changedChannels) + col.activeChannels().UpdateChanged(ctx, collectionID, changedChannels) } channelsSince = newChannelsSince } @@ -1052,7 +1052,7 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex // Clean up inactive lateSequenceFeeds (because user has lost access to the channel) for channel, lateFeed := range lateSequenceFeeds { if !lateFeed.active { - col.closeLateFeed(lateFeed) + col.closeLateFeed(ctx, lateFeed) delete(lateSequenceFeeds, channel) } else { lateFeed.active = false @@ -1104,8 +1104,8 @@ func (db *DatabaseCollectionWithUser) GetChanges(ctx context.Context, channels b } // Returns the set of cached log entries for a given channel -func (c *DatabaseCollection) GetChangeLog(channel channels.ID, afterSeq uint64) (entries []*LogEntry, err error) { - return c.changeCache().getChannelCache().GetCachedChanges(channel) +func (c *DatabaseCollection) GetChangeLog(ctx context.Context, channel channels.ID, afterSeq uint64) (entries []*LogEntry, err error) { + return c.changeCache().getChannelCache().GetCachedChanges(ctx, channel) } // WaitForSequenceNotSkipped blocks until the given sequence has been received or skipped by the change cache. @@ -1122,7 +1122,7 @@ func (c *DatabaseCollection) WaitForSequenceNotSkipped(ctx context.Context, sequ // WaitForPendingChanges blocks until the change-cache has caught up with the latest writes to the database. func (c *DatabaseCollection) WaitForPendingChanges(ctx context.Context) error { - lastSequence, err := c.LastSequence() + lastSequence, err := c.LastSequence(ctx) if err != nil { return err } @@ -1208,8 +1208,8 @@ func (db *DatabaseCollectionWithUser) getLateFeed(feedHandler *lateSequenceFeed, } // Closes a single late sequence feed. -func (db *DatabaseCollectionWithUser) closeLateFeed(feedHandler *lateSequenceFeed) { - singleChannelCache, err := db.changeCache().getChannelCache().getSingleChannelCache(feedHandler.channel) +func (db *DatabaseCollectionWithUser) closeLateFeed(ctx context.Context, feedHandler *lateSequenceFeed) { + singleChannelCache, err := db.changeCache().getChannelCache().getSingleChannelCache(ctx, feedHandler.channel) if err != nil || !singleChannelCache.SupportsLateFeed() { return } @@ -1219,9 +1219,9 @@ func (db *DatabaseCollectionWithUser) closeLateFeed(feedHandler *lateSequenceFee } // Closes set of feeds. Invoked on changes termination -func (db *DatabaseCollectionWithUser) closeLateFeeds(feeds map[channels.ID]*lateSequenceFeed) { +func (db *DatabaseCollectionWithUser) closeLateFeeds(ctx context.Context, feeds map[channels.ID]*lateSequenceFeed) { for _, feed := range feeds { - db.closeLateFeed(feed) + db.closeLateFeed(ctx, feed) } } diff --git a/db/changes_test.go b/db/changes_test.go index db28c12415..95dc721c66 100644 --- a/db/changes_test.go +++ b/db/changes_test.go @@ -74,7 +74,7 @@ func TestFilterToAvailableChannels(t *testing.T) { collection.user, err = auth.GetUser("test") require.NoError(t, err) - ch, err := collection.GetChanges(ctx, testCase.accessChans, getChangesOptionsWithZeroSeq()) + ch, err := collection.GetChanges(ctx, testCase.accessChans, getChangesOptionsWithZeroSeq(t)) require.NoError(t, err) require.Len(t, ch, len(testCase.expectedDocsReturned)) @@ -130,7 +130,7 @@ func TestChangesAfterChannelAdded(t *testing.T) { // Check the _changes feed: collection.user, _ = authenticator.GetUser("naomi") - changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq()) + changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err, "Couldn't GetChanges") printChanges(changes) require.Len(t, changes, 3) @@ -158,7 +158,7 @@ func TestChangesAfterChannelAdded(t *testing.T) { // Check the _changes feed -- this is to make sure the changeCache properly received // sequence 2 (the user doc) and isn't stuck waiting for it. cacheWaiter.AddAndWait(1) - changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(lastSeq)) + changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq)) assert.NoError(t, err, "Couldn't GetChanges (2nd)") @@ -167,7 +167,7 @@ func TestChangesAfterChannelAdded(t *testing.T) { assert.Equal(t, []ChangeRev{{"rev": revid}}, changes[0].Changes) // validate from zero - changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq()) + changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err, "Couldn't GetChanges") printChanges(changes) @@ -187,18 +187,18 @@ func getLastSeq(changes []*ChangeEntry) SequenceID { } // Makes changes options starting at sequence 0, with a new changes context -func getChangesOptionsWithZeroSeq() ChangesOptions { - return ChangesOptions{Since: SequenceID{Seq: 0}, ChangesCtx: context.Background()} +func getChangesOptionsWithZeroSeq(t testing.TB) ChangesOptions { + return ChangesOptions{Since: SequenceID{Seq: 0}, ChangesCtx: base.TestCtx(t)} } // Makes changes options with a since value of seq and a new changes context -func getChangesOptionsWithSeq(seq SequenceID) ChangesOptions { - return ChangesOptions{Since: seq, ChangesCtx: context.Background()} +func getChangesOptionsWithSeq(t *testing.T, seq SequenceID) ChangesOptions { + return ChangesOptions{Since: seq, ChangesCtx: base.TestCtx(t)} } // Makes changes options a new changes context -func getChangesOptionsWithCtxOnly() ChangesOptions { - return ChangesOptions{ChangesCtx: context.Background()} +func getChangesOptionsWithCtxOnly(t *testing.T) ChangesOptions { + return ChangesOptions{ChangesCtx: base.TestCtx(t)} } func TestDocDeletionFromChannelCoalescedRemoved(t *testing.T) { @@ -229,7 +229,7 @@ func TestDocDeletionFromChannelCoalescedRemoved(t *testing.T) { cacheWaiter.AddAndWait(1) collection.user, _ = authenticator.GetUser("alice") - changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq()) + changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t)) require.NoError(t, err, "Couldn't GetChanges") printChanges(changes) assert.Equal(t, 1, len(changes)) @@ -274,7 +274,7 @@ func TestDocDeletionFromChannelCoalescedRemoved(t *testing.T) { // Check the _changes feed -- this is to make sure the changeCache properly received // sequence 3 and isn't stuck waiting for it. cacheWaiter.AddAndWait(1) - changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(lastSeq)) + changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq)) assert.NoError(t, err, "Couldn't GetChanges (2nd)") @@ -315,7 +315,7 @@ func TestDocDeletionFromChannelCoalesced(t *testing.T) { cacheWaiter.AddAndWait(1) collection.user, _ = authenticator.GetUser("alice") - changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq()) + changes, err := collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err, "Couldn't GetChanges") printChanges(changes) @@ -358,7 +358,7 @@ func TestDocDeletionFromChannelCoalesced(t *testing.T) { // sequence 3 (the modified document) and isn't stuck waiting for it. cacheWaiter.AddAndWait(1) - changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(lastSeq)) + changes, err = collection.GetChanges(ctx, base.SetOf("*"), getChangesOptionsWithSeq(t, lastSeq)) assert.NoError(t, err, "Couldn't GetChanges (2nd)") @@ -402,7 +402,7 @@ func TestActiveOnlyCacheUpdate(t *testing.T) { changesOptions := ChangesOptions{ Since: SequenceID{Seq: 0}, ActiveOnly: true, - ChangesCtx: context.Background(), + ChangesCtx: base.TestCtx(t), } initQueryCount := db.DbStats.Cache().ViewQueries.Value() @@ -504,7 +504,7 @@ func BenchmarkChangesFeedDocUnmarshalling(b *testing.B) { // Changes params: POST /pm/_changes?feed=normal&heartbeat=30000&style=all_docs&active_only=true // Changes request of all docs (could also do GetDoc call, but misses other possible things). One shot, .. etc - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(b)) options.ChangesCtx = changesCtx feed, err := collection.MultiChangesFeed(ctx, base.SetOf("*"), options) if err != nil { diff --git a/db/channel_cache.go b/db/channel_cache.go index 06842fac56..a51055e224 100644 --- a/db/channel_cache.go +++ b/db/channel_cache.go @@ -39,7 +39,7 @@ type ChannelCache interface { Init(initialSequence uint64) // Adds an entry to the cache, returns set of channels it was added to - AddToCache(change *LogEntry) []channels.ID + AddToCache(ctx context.Context, change *LogEntry) []channels.ID // Notifies the cache of a principal update. Updates the cache's high sequence AddPrincipal(change *LogEntry) @@ -48,34 +48,36 @@ type ChannelCache interface { AddSkippedSequence(change *LogEntry) // Remove purges the given doc IDs from all channel caches and returns the number of items removed. - Remove(collectionID uint32, docIDs []string, startTime time.Time) (count int) + Remove(ctx context.Context, collectionID uint32, docIDs []string, startTime time.Time) (count int) // Returns set of changes for a given channel, within the bounds specified in options GetChanges(ctx context.Context, ch channels.ID, options ChangesOptions) ([]*LogEntry, error) // Returns the set of all cached data for a given channel (intended for diagnostic usage) - GetCachedChanges(ch channels.ID) ([]*LogEntry, error) + GetCachedChanges(ctx context.Context, ch channels.ID) ([]*LogEntry, error) // Clear reinitializes the cache to an empty state Clear() // Size of the the largest individual channel cache, invoked for stats reporting // // TODO: let the cache manage its own stats internally (maybe take an updateStats call) - MaxCacheSize() int + MaxCacheSize(context.Context) int // Returns the highest cached sequence, used for changes synchronization GetHighCacheSequence() uint64 // Access to individual channel cache - getSingleChannelCache(ch channels.ID) (SingleChannelCache, error) + getSingleChannelCache(ctx context.Context, ch channels.ID) (SingleChannelCache, error) // Access to individual bypass channel cache getBypassChannelCache(ch channels.ID) (SingleChannelCache, error) // Stop stops the channel cache and it's background tasks. - Stop() + Stop(context.Context) } +var _ ChannelCache = &channelCacheImpl{} + // ChannelQueryHandler interface is implemented by databaseContext and databaseCollection. type ChannelQueryHandler interface { getChangesInChannelFromQuery(ctx context.Context, channelName string, startSeq, endSeq uint64, limit int, activeOnly bool) (LogEntries, error) @@ -129,7 +131,7 @@ func newChannelCache(ctx context.Context, dbName string, options ChannelCacheOpt return nil, err } channelCache.backgroundTasks = append(channelCache.backgroundTasks, bgt) - base.DebugfCtx(context.Background(), base.KeyCache, "Initialized channel cache with maxChannels:%d, HWM: %d, LWM: %d", + base.DebugfCtx(ctx, base.KeyCache, "Initialized channel cache with maxChannels:%d, HWM: %d, LWM: %d", channelCache.maxChannels, channelCache.compactHighWatermark, channelCache.compactLowWatermark) return channelCache, nil } @@ -141,12 +143,12 @@ func (c *channelCacheImpl) Clear() { } // Stop stops the channel cache and it's background tasks. -func (c *channelCacheImpl) Stop() { +func (c *channelCacheImpl) Stop(ctx context.Context) { // Signal to terminate channel cache background tasks. close(c.terminator) // Wait for channel cache background tasks to finish. - waitForBGTCompletion(context.TODO(), BGTCompletionMaxWait, c.backgroundTasks, c.dbName) + waitForBGTCompletion(ctx, BGTCompletionMaxWait, c.backgroundTasks, c.dbName) } func (c *channelCacheImpl) Init(initialSequence uint64) { @@ -180,9 +182,9 @@ func (c *channelCacheImpl) updateHighCacheSequence(sequence uint64) { // GetSingleChannelCache will create the cache for the channel if it doesn't exist. If the cache is at // capacity, will return a bypass channel cache. -func (c *channelCacheImpl) getSingleChannelCache(ch channels.ID) (SingleChannelCache, error) { +func (c *channelCacheImpl) getSingleChannelCache(ctx context.Context, ch channels.ID) (SingleChannelCache, error) { - return c.getChannelCache(ch) + return c.getChannelCache(ctx, ch) } func (c *channelCacheImpl) AddPrincipal(change *LogEntry) { @@ -196,7 +198,7 @@ func (c *channelCacheImpl) AddSkippedSequence(change *LogEntry) { // Adds an entry to the appropriate channels' caches, returning the affected channels. lateSequence // flag indicates whether it was a change arriving out of sequence -func (c *channelCacheImpl) AddToCache(change *LogEntry) (updatedChannels []channels.ID) { +func (c *channelCacheImpl) AddToCache(ctx context.Context, change *LogEntry) (updatedChannels []channels.ID) { ch := change.Channels change.Channels = nil // not needed anymore, so free some memory @@ -210,7 +212,7 @@ func (c *channelCacheImpl) AddToCache(change *LogEntry) (updatedChannels []chann // twice) if change.Skipped { c.lateSeqLock.Lock() - base.InfofCtx(context.TODO(), base.KeyChanges, "Acquired late sequence lock in order to cache %d - doc %q / %q", change.Sequence, base.UD(change.DocID), change.RevID) + base.InfofCtx(ctx, base.KeyChanges, "Acquired late sequence lock in order to cache %d - doc %q / %q", change.Sequence, base.UD(change.DocID), change.RevID) defer c.lateSeqLock.Unlock() } @@ -225,9 +227,9 @@ func (c *channelCacheImpl) AddToCache(change *LogEntry) (updatedChannels []chann if channelName == channels.UserStarChannel { explicitStarChannel = true } - channelCache, ok := c.getActiveChannelCache(channels.NewID(channelName, change.CollectionID)) + channelCache, ok := c.getActiveChannelCache(ctx, channels.NewID(channelName, change.CollectionID)) if ok { - channelCache.addToCache(change, removal != nil) + channelCache.addToCache(ctx, change, removal != nil) if change.Skipped { channelCache.AddLateSequence(change) } @@ -238,9 +240,9 @@ func (c *channelCacheImpl) AddToCache(change *LogEntry) (updatedChannels []chann } if EnableStarChannelLog && !explicitStarChannel { - channelCache, ok := c.getActiveChannelCache(channels.NewID(channels.UserStarChannel, change.CollectionID)) + channelCache, ok := c.getActiveChannelCache(ctx, channels.NewID(channels.UserStarChannel, change.CollectionID)) if ok { - channelCache.addToCache(change, false) + channelCache.addToCache(ctx, change, false) if change.Skipped { channelCache.AddLateSequence(change) } @@ -255,21 +257,21 @@ func (c *channelCacheImpl) AddToCache(change *LogEntry) (updatedChannels []chann // Remove purges the given doc IDs from all channel caches and returns the number of items removed. // count will be larger than the input slice if the same document is removed from multiple channel caches. -func (c *channelCacheImpl) Remove(collectionID uint32, docIDs []string, startTime time.Time) (count int) { +func (c *channelCacheImpl) Remove(ctx context.Context, collectionID uint32, docIDs []string, startTime time.Time) (count int) { // Exit early if there's no work to do if len(docIDs) == 0 { return 0 } removeCallback := func(v interface{}) bool { - channelCache := AsSingleChannelCache(v) + channelCache := AsSingleChannelCache(ctx, v) if channelCache == nil { return false } if channelCache.ChannelID().CollectionID != collectionID { return true } - count += channelCache.Remove(collectionID, docIDs, startTime) + count += channelCache.Remove(ctx, collectionID, docIDs, startTime) return true } @@ -280,16 +282,16 @@ func (c *channelCacheImpl) Remove(collectionID uint32, docIDs []string, startTim func (c *channelCacheImpl) GetChanges(ctx context.Context, ch channels.ID, options ChangesOptions) ([]*LogEntry, error) { - cache, err := c.getChannelCache(ch) + cache, err := c.getChannelCache(ctx, ch) if err != nil { return nil, err } return cache.GetChanges(ctx, options) } -func (c *channelCacheImpl) GetCachedChanges(channel channels.ID) ([]*LogEntry, error) { +func (c *channelCacheImpl) GetCachedChanges(ctx context.Context, channel channels.ID) ([]*LogEntry, error) { options := ChangesOptions{Since: SequenceID{Seq: 0}} - cache, err := c.getChannelCache(channel) + cache, err := c.getChannelCache(ctx, channel) if err != nil { return nil, err } @@ -301,7 +303,7 @@ func (c *channelCacheImpl) GetCachedChanges(channel channels.ID) ([]*LogEntry, e func (c *channelCacheImpl) cleanAgedItems(ctx context.Context) error { callback := func(v interface{}) bool { - channelCache := AsSingleChannelCache(v) + channelCache := AsSingleChannelCache(ctx, v) if channelCache == nil { return false } @@ -313,15 +315,15 @@ func (c *channelCacheImpl) cleanAgedItems(ctx context.Context) error { return nil } -func (c *channelCacheImpl) getChannelCache(channel channels.ID) (SingleChannelCache, error) { +func (c *channelCacheImpl) getChannelCache(ctx context.Context, channel channels.ID) (SingleChannelCache, error) { cacheValue, found := c.channelCaches.Get(channel) if found { - return AsSingleChannelCache(cacheValue), nil + return AsSingleChannelCache(ctx, cacheValue), nil } // Attempt to add a singleChannelCache for the channel name. If unsuccessful, return a bypass channel cache - singleChannelCache, ok := c.addChannelCache(channel) + singleChannelCache, ok := c.addChannelCache(ctx, channel) if ok { return singleChannelCache, nil } @@ -354,10 +356,10 @@ func (c *channelCacheImpl) getBypassChannelCache(ch channels.ID) (SingleChannelC // Converts an RangeSafeCollection value to a singleChannelCacheImpl. On type // conversion error, logs a warning and returns nil. -func AsSingleChannelCache(cacheValue interface{}) *singleChannelCacheImpl { +func AsSingleChannelCache(ctx context.Context, cacheValue interface{}) *singleChannelCacheImpl { singleChannelCache, ok := cacheValue.(*singleChannelCacheImpl) if !ok { - base.WarnfCtx(context.Background(), "Unexpected channel cache value type: %T", cacheValue) + base.WarnfCtx(ctx, "Unexpected channel cache value type: %T", cacheValue) return nil } return singleChannelCache @@ -371,7 +373,7 @@ func AsSingleChannelCache(cacheValue interface{}) *singleChannelCacheImpl { // // 4. addChannelCache initializes cache with validFrom=10 and adds to c.channelCaches // // This scenario would result in sequence 11 missing from the cache. Locking seqLock ensures that // // step 3 blocks until step 4 is complete (and so sees the channel as active) -func (c *channelCacheImpl) addChannelCache(channel channels.ID) (*singleChannelCacheImpl, bool) { +func (c *channelCacheImpl) addChannelCache(ctx context.Context, channel channels.ID) (*singleChannelCacheImpl, bool) { // Return nil if the cache at capacity. if c.channelCaches.Length() >= c.maxChannels { @@ -390,14 +392,14 @@ func (c *channelCacheImpl) addChannelCache(channel channels.ID) (*singleChannelC validFrom := c.GetHighCacheSequence() + 1 singleChannelCache := - newChannelCacheWithOptions(queryHandler, channel, validFrom, c.options, c.cacheStats) + newChannelCacheWithOptions(ctx, queryHandler, channel, validFrom, c.options, c.cacheStats) cacheValue, created, cacheSize := c.channelCaches.GetOrInsert(channel, singleChannelCache) c.validFromLock.Unlock() - singleChannelCache = AsSingleChannelCache(cacheValue) + singleChannelCache = AsSingleChannelCache(ctx, cacheValue) if cacheSize > c.compactHighWatermark { - c.startCacheCompaction() + c.startCacheCompaction(ctx) } if created { @@ -408,21 +410,21 @@ func (c *channelCacheImpl) addChannelCache(channel channels.ID) (*singleChannelC return singleChannelCache, true } -func (c *channelCacheImpl) getActiveChannelCache(channel channels.ID) (*singleChannelCacheImpl, bool) { +func (c *channelCacheImpl) getActiveChannelCache(ctx context.Context, channel channels.ID) (*singleChannelCacheImpl, bool) { cacheValue, found := c.channelCaches.Get(channel) if !found { return nil, false } - cache := AsSingleChannelCache(cacheValue) + cache := AsSingleChannelCache(ctx, cacheValue) return cache, cache != nil } -func (c *channelCacheImpl) MaxCacheSize() int { +func (c *channelCacheImpl) MaxCacheSize(ctx context.Context) int { maxCacheSize := 0 callback := func(v interface{}) bool { - channelCache := AsSingleChannelCache(v) + channelCache := AsSingleChannelCache(ctx, v) if channelCache == nil { return false } @@ -442,30 +444,29 @@ func (c *channelCacheImpl) isCompactActive() bool { } // startCacheCompaction starts a goroutine for cache compaction if it's not already running. -func (c *channelCacheImpl) startCacheCompaction() { +func (c *channelCacheImpl) startCacheCompaction(ctx context.Context) { compactNotStarted := c.compactRunning.CompareAndSwap(false, true) if compactNotStarted { - go c.compactChannelCache() + go c.compactChannelCache(ctx) } } // Compact runs until the number of channels in the cache is lower than compactLowWatermark -func (c *channelCacheImpl) compactChannelCache() { +func (c *channelCacheImpl) compactChannelCache(ctx context.Context) { defer c.compactRunning.Set(false) // Increment compact count on start, as timing is updated per loop iteration c.cacheStats.ChannelCacheCompactCount.Add(1) - logCtx := context.TODO() cacheSize := c.channelCaches.Length() - base.InfofCtx(logCtx, base.KeyCache, "Starting channel cache compaction, size %d", cacheSize) + base.InfofCtx(ctx, base.KeyCache, "Starting channel cache compaction, size %d", cacheSize) for { // channelCache close handling compactIterationStart := time.Now() select { case <-c.terminator: - base.DebugfCtx(logCtx, base.KeyCache, "Channel cache compaction stopped due to cache close.") + base.DebugfCtx(ctx, base.KeyCache, "Channel cache compaction stopped due to cache close.") return default: // continue @@ -474,10 +475,10 @@ func (c *channelCacheImpl) compactChannelCache() { // Maintain a target number of items to compact per iteration. Break the list iteration when the target is reached targetEvictCount := cacheSize - c.compactLowWatermark if targetEvictCount <= 0 { - base.InfofCtx(logCtx, base.KeyCache, "Stopping channel cache compaction, size %d", cacheSize) + base.InfofCtx(ctx, base.KeyCache, "Stopping channel cache compaction, size %d", cacheSize) return } - base.TracefCtx(logCtx, base.KeyCache, "Target eviction count: %d (lwm:%d)", targetEvictCount, c.compactLowWatermark) + base.TracefCtx(ctx, base.KeyCache, "Target eviction count: %d (lwm:%d)", targetEvictCount, c.compactLowWatermark) // Iterates through cache entries based on cache size at start of compaction iteration loop. Intentionally // ignores channels added during compaction iteration @@ -491,7 +492,7 @@ func (c *channelCacheImpl) compactChannelCache() { elementCount++ singleChannelCache, ok := elem.Value.(*singleChannelCacheImpl) if !ok { - base.WarnfCtx(logCtx, "Non-cache entry (%T) found in channel cache during compaction - ignoring", elem.Value) + base.WarnfCtx(ctx, "Non-cache entry (%T) found in channel cache during compaction - ignoring", elem.Value) return true } @@ -504,16 +505,16 @@ func (c *channelCacheImpl) compactChannelCache() { // Determine whether NRU channel is active, to establish eviction priority isActive := c.activeChannels.IsActive(singleChannelCache.channelID) if !isActive { - base.TracefCtx(logCtx, base.KeyCache, "Marking inactive cache entry %q for eviction ", base.UD(singleChannelCache.channelID)) + base.TracefCtx(ctx, base.KeyCache, "Marking inactive cache entry %q for eviction ", base.UD(singleChannelCache.channelID)) inactiveEvictionCandidates = append(inactiveEvictionCandidates, elem) } else { - base.TracefCtx(logCtx, base.KeyCache, "Marking NRU cache entry %q for eviction", base.UD(singleChannelCache.channelID)) + base.TracefCtx(ctx, base.KeyCache, "Marking NRU cache entry %q for eviction", base.UD(singleChannelCache.channelID)) nruEvictionCandidates = append(nruEvictionCandidates, elem) } // If we have enough inactive channels to reach targetCount, terminate range if len(inactiveEvictionCandidates) >= targetEvictCount { - base.TracefCtx(logCtx, base.KeyCache, "Eviction count target (%d) reached with inactive channels, proceeding to removal", targetEvictCount) + base.TracefCtx(ctx, base.KeyCache, "Eviction count target (%d) reached with inactive channels, proceeding to removal", targetEvictCount) return false } return true @@ -549,7 +550,7 @@ func (c *channelCacheImpl) compactChannelCache() { // Update eviction stats c.updateEvictionStats(inactiveEvictCount, len(evictionElements), compactIterationStart) - base.TracefCtx(logCtx, base.KeyCache, "Compact iteration complete - eviction count: %d (lwm:%d)", len(evictionElements), c.compactLowWatermark) + base.TracefCtx(ctx, base.KeyCache, "Compact iteration complete - eviction count: %d (lwm:%d)", len(evictionElements), c.compactLowWatermark) } } diff --git a/db/channel_cache_single.go b/db/channel_cache_single.go index 9cc87316c9..d3c797329c 100644 --- a/db/channel_cache_single.go +++ b/db/channel_cache_single.go @@ -134,7 +134,7 @@ func newSingleChannelCache(queryHandler ChannelQueryHandler, channel channels.ID return cache } -func newChannelCacheWithOptions(queryHandler ChannelQueryHandler, channel channels.ID, validFrom uint64, options ChannelCacheOptions, cacheStats *base.CacheStats) *singleChannelCacheImpl { +func newChannelCacheWithOptions(ctx context.Context, queryHandler ChannelQueryHandler, channel channels.ID, validFrom uint64, options ChannelCacheOptions, cacheStats *base.CacheStats) *singleChannelCacheImpl { cache := newSingleChannelCache(queryHandler, channel, validFrom, cacheStats) // Update cache options when present @@ -154,7 +154,7 @@ func newChannelCacheWithOptions(queryHandler ChannelQueryHandler, channel channe cache.options.MaxNumChannels = options.MaxNumChannels } - base.DebugfCtx(context.Background(), base.KeyCache, "Initialized cache for channel %q with min:%v max:%v age:%v, validFrom: %d", + base.DebugfCtx(ctx, base.KeyCache, "Initialized cache for channel %q with min:%v max:%v age:%v, validFrom: %d", base.UD(cache.channelID), cache.options.ChannelCacheMinLength, cache.options.ChannelCacheMaxLength, cache.options.ChannelCacheAge, validFrom) return cache @@ -183,24 +183,24 @@ func (c *singleChannelCacheImpl) SupportsLateFeed() bool { } // Low-level method to add a LogEntry to a single channel's cache. -func (c *singleChannelCacheImpl) addToCache(change *LogEntry, isRemoval bool) { +func (c *singleChannelCacheImpl) addToCache(ctx context.Context, change *LogEntry, isRemoval bool) { c.lock.Lock() defer c.lock.Unlock() if c.wouldBeImmediatelyPruned(change) { - base.InfofCtx(context.TODO(), base.KeyCache, "Not adding change #%d doc %q / %q ==> channel %q, since it will be immediately pruned", + base.InfofCtx(ctx, base.KeyCache, "Not adding change #%d doc %q / %q ==> channel %q, since it will be immediately pruned", change.Sequence, base.UD(change.DocID), change.RevID, base.UD(c.channelID)) return } if !isRemoval { - c._appendChange(change) + c._appendChange(ctx, change) } else { removalChange := *change removalChange.Flags |= channels.Removed - c._appendChange(&removalChange) + c._appendChange(ctx, &removalChange) } - c._pruneCacheLength() + c._pruneCacheLength(ctx) } // If certain conditions are met, it's possible that this change will be added and then @@ -218,7 +218,7 @@ func (c *singleChannelCacheImpl) wouldBeImmediatelyPruned(change *LogEntry) bool } // Remove purges the given doc IDs from the channel cache and returns the number of items removed. -func (c *singleChannelCacheImpl) Remove(collectionID uint32, docIDs []string, startTime time.Time) (count int) { +func (c *singleChannelCacheImpl) Remove(ctx context.Context, collectionID uint32, docIDs []string, startTime time.Time) (count int) { // Exit early if there's no work to do if len(docIDs) == 0 { return 0 @@ -226,7 +226,6 @@ func (c *singleChannelCacheImpl) Remove(collectionID uint32, docIDs []string, st c.lock.Lock() defer c.lock.Unlock() - logCtx := context.TODO() // Build subset of docIDs that we know are present in the cache foundDocs := make(map[string]struct{}, 0) @@ -245,7 +244,7 @@ func (c *singleChannelCacheImpl) Remove(collectionID uint32, docIDs []string, st // Make sure the document we're about to remove is older than the start time of the purge // This is to ensure that resurrected documents do not accidentally get removed. if c.logs[i].TimeReceived.After(startTime) { - base.DebugfCtx(logCtx, base.KeyCache, "Skipping removal of doc %q from cache %q - received after purge", + base.DebugfCtx(ctx, base.KeyCache, "Skipping removal of doc %q from cache %q - received after purge", base.UD(docID), base.UD(c.channelID)) continue } @@ -260,7 +259,7 @@ func (c *singleChannelCacheImpl) Remove(collectionID uint32, docIDs []string, st delete(c.cachedDocIDs, docID) count++ - base.TracefCtx(logCtx, base.KeyCache, "Removed doc %q from cache %q", base.UD(docID), base.UD(c.channelID)) + base.TracefCtx(ctx, base.KeyCache, "Removed doc %q from cache %q", base.UD(docID), base.UD(c.channelID)) } } @@ -268,7 +267,7 @@ func (c *singleChannelCacheImpl) Remove(collectionID uint32, docIDs []string, st } // Internal helper that prunes a single channel's cache. Caller MUST be holding the lock. -func (c *singleChannelCacheImpl) _pruneCacheLength() (pruned int) { +func (c *singleChannelCacheImpl) _pruneCacheLength(ctx context.Context) (pruned int) { // If we are over max length, prune it down to max length if len(c.logs) > c.options.ChannelCacheMaxLength { pruned = len(c.logs) - c.options.ChannelCacheMaxLength @@ -281,7 +280,7 @@ func (c *singleChannelCacheImpl) _pruneCacheLength() (pruned int) { } if pruned > 0 { - base.DebugfCtx(context.TODO(), base.KeyCache, "Pruned %d entries from channel %q", pruned, base.UD(c.channelID)) + base.DebugfCtx(ctx, base.KeyCache, "Pruned %d entries from channel %q", pruned, base.UD(c.channelID)) } return pruned @@ -462,13 +461,13 @@ func (c *singleChannelCacheImpl) _adjustFirstSeq(change *LogEntry) { // Adds an entry to the end of an array of LogEntries. // Any existing entry with the same DocID is removed. -func (c *singleChannelCacheImpl) _appendChange(change *LogEntry) { +func (c *singleChannelCacheImpl) _appendChange(ctx context.Context, change *LogEntry) { log := c.logs end := len(log) - 1 if end >= 0 { if change.Sequence <= log[end].Sequence { - base.DebugfCtx(context.TODO(), base.KeyCache, "LogEntries.appendChange: out-of-order sequence #%d (last is #%d) - handling as insert", + base.DebugfCtx(ctx, base.KeyCache, "LogEntries.appendChange: out-of-order sequence #%d (last is #%d) - handling as insert", change.Sequence, log[end].Sequence) // insert the change in the array, ensuring the docID isn't already present c.insertChange(&c.logs, change) diff --git a/db/channel_cache_single_test.go b/db/channel_cache_single_test.go index bd896eba4f..7b2c883b92 100644 --- a/db/channel_cache_single_test.go +++ b/db/channel_cache_single_test.go @@ -43,35 +43,35 @@ func TestDuplicateDocID(t *testing.T) { assert.NotNil(t, cache) // Add some entries to cache - cache.addToCache(testLogEntry(1, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(2, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(3, "doc5", "5-a"), false) + cache.addToCache(ctx, testLogEntry(1, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(2, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(3, "doc5", "5-a"), false) - entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{1, 2, 3})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc3", "doc5"})) assert.True(t, err == nil) // Add a new revision matching mid-list - cache.addToCache(testLogEntry(4, "doc3", "3-b"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(4, "doc3", "3-b"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{1, 3, 4})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc5", "doc3"})) assert.True(t, err == nil) // Add a new revision matching first - cache.addToCache(testLogEntry(5, "doc1", "1-b"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(5, "doc1", "1-b"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{3, 4, 5})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc5", "doc3", "doc1"})) assert.True(t, err == nil) // Add a new revision matching last - cache.addToCache(testLogEntry(6, "doc1", "1-c"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(6, "doc1", "1-c"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{3, 4, 6})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc5", "doc3", "doc1"})) @@ -98,11 +98,11 @@ func TestLateArrivingSequence(t *testing.T) { assert.NotNil(t, cache) // Add some entries to cache - cache.addToCache(testLogEntry(1, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(3, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(5, "doc5", "5-a"), false) + cache.addToCache(ctx, testLogEntry(1, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(3, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(5, "doc5", "5-a"), false) - entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{1, 3, 5})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc3", "doc5"})) @@ -110,8 +110,8 @@ func TestLateArrivingSequence(t *testing.T) { // Add a late-arriving sequence cache.AddLateSequence(testLogEntry(2, "doc2", "2-a")) - cache.addToCache(testLogEntry(2, "doc2", "2-a"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(2, "doc2", "2-a"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{1, 2, 3, 5})) @@ -139,11 +139,11 @@ func TestLateSequenceAsFirst(t *testing.T) { assert.NotNil(t, cache) // Add some entries to cache - cache.addToCache(testLogEntry(5, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(10, "doc2", "2-a"), false) - cache.addToCache(testLogEntry(15, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(5, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(10, "doc2", "2-a"), false) + cache.addToCache(ctx, testLogEntry(15, "doc3", "3-a"), false) - entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{5, 10, 15})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc2", "doc3"})) @@ -151,8 +151,8 @@ func TestLateSequenceAsFirst(t *testing.T) { // Add a late-arriving sequence cache.AddLateSequence(testLogEntry(3, "doc0", "0-a")) - cache.addToCache(testLogEntry(3, "doc0", "0-a"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(3, "doc0", "0-a"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{3, 5, 10, 15})) @@ -180,12 +180,12 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { assert.NotNil(t, cache) // Add some entries to cache - cache.addToCache(testLogEntry(10, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(20, "doc2", "2-a"), false) - cache.addToCache(testLogEntry(30, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(40, "doc4", "4-a"), false) + cache.addToCache(ctx, testLogEntry(10, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(20, "doc2", "2-a"), false) + cache.addToCache(ctx, testLogEntry(30, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(40, "doc4", "4-a"), false) - entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) assert.True(t, verifyChannelSequences(entries, []uint64{10, 20, 30, 40})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc2", "doc3", "doc4"})) @@ -193,8 +193,8 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { // Add a late-arriving sequence that should replace earlier sequence cache.AddLateSequence(testLogEntry(25, "doc1", "1-c")) - cache.addToCache(testLogEntry(25, "doc1", "1-c"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(25, "doc1", "1-c"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{20, 25, 30, 40})) @@ -203,8 +203,8 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { // Add a late-arriving sequence that should be ignored (later sequence exists for that docID) cache.AddLateSequence(testLogEntry(15, "doc1", "1-b")) - cache.addToCache(testLogEntry(15, "doc1", "1-b"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(15, "doc1", "1-b"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{20, 25, 30, 40})) @@ -213,8 +213,8 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { // Add a late-arriving sequence adjacent to same ID (cache inserts differently) cache.AddLateSequence(testLogEntry(27, "doc1", "1-d")) - cache.addToCache(testLogEntry(27, "doc1", "1-d"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(27, "doc1", "1-d"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{20, 27, 30, 40})) @@ -223,8 +223,8 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { // Add a late-arriving sequence adjacent to same ID (cache inserts differently) cache.AddLateSequence(testLogEntry(41, "doc4", "4-b")) - cache.addToCache(testLogEntry(41, "doc4", "4-b"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(41, "doc4", "4-b"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{20, 27, 30, 41})) @@ -233,8 +233,8 @@ func TestDuplicateLateArrivingSequence(t *testing.T) { // Add late arriving that's duplicate of oldest in cache cache.AddLateSequence(testLogEntry(45, "doc2", "2-b")) - cache.addToCache(testLogEntry(45, "doc2", "2-b"), false) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.addToCache(ctx, testLogEntry(45, "doc2", "2-b"), false) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 4) writeEntries(entries) assert.True(t, verifyChannelSequences(entries, []uint64{27, 30, 41, 45})) @@ -272,7 +272,7 @@ func TestPrependChanges(t *testing.T) { assert.Equal(t, 3, numPrepended) // Validate cache - validFrom, cachedChanges := cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges := cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(5), validFrom) require.Len(t, cachedChanges, 3) @@ -283,8 +283,8 @@ func TestPrependChanges(t *testing.T) { require.NoError(t, err) cache = newSingleChannelCache(collection, channels.NewID("PrependPopulatedCache", collection.GetCollectionID()), 0, dbstats.Cache()) cache.validFrom = 13 - cache.addToCache(testLogEntry(14, "doc1", "2-a"), false) - cache.addToCache(testLogEntry(20, "doc2", "3-a"), false) + cache.addToCache(ctx, testLogEntry(14, "doc1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(20, "doc2", "3-a"), false) // Prepend changesToPrepend = LogEntries{ @@ -298,7 +298,7 @@ func TestPrependChanges(t *testing.T) { assert.Equal(t, 2, numPrepended) // Validate cache - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(5), validFrom) require.Len(t, cachedChanges, 4) if len(cachedChanges) == 4 { @@ -313,8 +313,8 @@ func TestPrependChanges(t *testing.T) { } // Write a new revision for a prepended doc to the cache, validate that old entry is removed - cache.addToCache(testLogEntry(24, "doc3", "3-a"), false) - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + cache.addToCache(ctx, testLogEntry(24, "doc3", "3-a"), false) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(5), validFrom) require.Len(t, cachedChanges, 4) if len(cachedChanges) == 4 { @@ -330,7 +330,7 @@ func TestPrependChanges(t *testing.T) { // Prepend empty set, validate validFrom update cache.prependChanges(ctx, LogEntries{}, 5, 14) - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(5), validFrom) require.Len(t, cachedChanges, 4) @@ -342,10 +342,10 @@ func TestPrependChanges(t *testing.T) { cache = newSingleChannelCache(collection, channels.NewID("PrependToFillCache", collection.GetCollectionID()), 0, dbstats.Cache()) cache.options.ChannelCacheMaxLength = 5 cache.validFrom = 13 - cache.addToCache(testLogEntry(14, "doc1", "2-a"), false) - cache.addToCache(testLogEntry(20, "doc2", "3-a"), false) - cache.addToCache(testLogEntry(22, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(23, "doc4", "3-a"), false) + cache.addToCache(ctx, testLogEntry(14, "doc1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(20, "doc2", "3-a"), false) + cache.addToCache(ctx, testLogEntry(22, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(23, "doc4", "3-a"), false) // Prepend changes. Only room for one more in cache. doc1 and doc2 should be ignored (already in cache), doc 6 should get cached, doc 5 should be discarded. validFrom should be doc6 (10) changesToPrepend = LogEntries{ @@ -359,7 +359,7 @@ func TestPrependChanges(t *testing.T) { assert.Equal(t, 1, numPrepended) // Validate cache - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(10), validFrom) require.Len(t, cachedChanges, 5) if len(cachedChanges) == 5 { @@ -382,10 +382,10 @@ func TestPrependChanges(t *testing.T) { require.NoError(t, err) cache = newSingleChannelCache(collection, channels.NewID("PrependDuplicatesOnly", collection.GetCollectionID()), 0, dbstats.Cache()) cache.validFrom = 13 - cache.addToCache(testLogEntry(14, "doc1", "2-a"), false) - cache.addToCache(testLogEntry(20, "doc2", "3-a"), false) - cache.addToCache(testLogEntry(22, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(23, "doc4", "3-a"), false) + cache.addToCache(ctx, testLogEntry(14, "doc1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(20, "doc2", "3-a"), false) + cache.addToCache(ctx, testLogEntry(22, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(23, "doc4", "3-a"), false) changesToPrepend = LogEntries{ testLogEntry(8, "doc2", "2-a"), @@ -395,7 +395,7 @@ func TestPrependChanges(t *testing.T) { } numPrepended = cache.prependChanges(ctx, changesToPrepend, 5, 14) assert.Equal(t, 0, numPrepended) - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(5), validFrom) require.Len(t, cachedChanges, 4) if len(cachedChanges) == 5 { @@ -417,11 +417,11 @@ func TestPrependChanges(t *testing.T) { cache = newSingleChannelCache(collection, channels.NewID("PrependFullCache", collection.GetCollectionID()), 0, dbstats.Cache()) cache.options.ChannelCacheMaxLength = 5 cache.validFrom = 13 - cache.addToCache(testLogEntry(14, "doc1", "2-a"), false) - cache.addToCache(testLogEntry(20, "doc2", "3-a"), false) - cache.addToCache(testLogEntry(22, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(23, "doc4", "3-a"), false) - cache.addToCache(testLogEntry(25, "doc5", "3-a"), false) + cache.addToCache(ctx, testLogEntry(14, "doc1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(20, "doc2", "3-a"), false) + cache.addToCache(ctx, testLogEntry(22, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(23, "doc4", "3-a"), false) + cache.addToCache(ctx, testLogEntry(25, "doc5", "3-a"), false) // Prepend changes, no room for in cache. changesToPrepend = LogEntries{ @@ -435,7 +435,7 @@ func TestPrependChanges(t *testing.T) { assert.Equal(t, 0, numPrepended) // Validate cache - validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly()) + validFrom, cachedChanges = cache.GetCachedChanges(getChangesOptionsWithCtxOnly(t)) assert.Equal(t, uint64(13), validFrom) require.Len(t, cachedChanges, 5) if len(cachedChanges) == 5 { @@ -471,19 +471,19 @@ func TestChannelCacheRemove(t *testing.T) { cache := newSingleChannelCache(collection, channels.NewID("Test1", collectionID), 0, dbstats.Cache()) // Add some entries to cache - cache.addToCache(testLogEntry(1, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(2, "doc3", "3-a"), false) - cache.addToCache(testLogEntry(3, "doc5", "5-a"), false) + cache.addToCache(ctx, testLogEntry(1, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(2, "doc3", "3-a"), false) + cache.addToCache(ctx, testLogEntry(3, "doc5", "5-a"), false) - entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + entries, err := cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 3) assert.True(t, verifyChannelSequences(entries, []uint64{1, 2, 3})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc1", "doc3", "doc5"})) assert.True(t, err == nil) // Now remove doc1 - cache.Remove(collectionID, []string{"doc1"}, time.Now()) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.Remove(ctx, collectionID, []string{"doc1"}, time.Now()) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 2) assert.True(t, verifyChannelSequences(entries, []uint64{2, 3})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc3", "doc5"})) @@ -492,8 +492,8 @@ func TestChannelCacheRemove(t *testing.T) { // Try to remove doc5 with a startTime before it was added to ensure it's not removed // This will print a debug level log: // [DBG] Cache+: Skipping removal of doc "doc5" from cache "Test1" - received after purge - cache.Remove(collectionID, []string{"doc5"}, time.Now().Add(-time.Second*5)) - entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq()) + cache.Remove(ctx, collectionID, []string{"doc5"}, time.Now().Add(-time.Second*5)) + entries, err = cache.GetChanges(ctx, getChangesOptionsWithZeroSeq(t)) require.Len(t, entries, 2) assert.True(t, verifyChannelSequences(entries, []uint64{2, 3})) assert.True(t, verifyChannelDocIDs(entries, []string{"doc3", "doc5"})) @@ -520,9 +520,9 @@ func TestChannelCacheStats(t *testing.T) { cache := newSingleChannelCache(collection, channels.NewID("Test1", collectionID), 0, testStats) // Add some entries to cache - cache.addToCache(testLogEntry(1, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(2, "doc2", "1-a"), false) - cache.addToCache(testLogEntry(3, "doc3", "1-a"), false) + cache.addToCache(ctx, testLogEntry(1, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(2, "doc2", "1-a"), false) + cache.addToCache(ctx, testLogEntry(3, "doc3", "1-a"), false) active, tombstones, removals := getCacheUtilization(testStats) assert.Equal(t, 3, active) @@ -530,22 +530,22 @@ func TestChannelCacheStats(t *testing.T) { assert.Equal(t, 0, removals) // Update keys already present in the cache, shouldn't modify utilization - cache.addToCache(testLogEntry(4, "doc1", "2-a"), false) - cache.addToCache(testLogEntry(5, "doc2", "2-a"), false) + cache.addToCache(ctx, testLogEntry(4, "doc1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(5, "doc2", "2-a"), false) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 3, active) assert.Equal(t, 0, tombstones) assert.Equal(t, 0, removals) // Add a removal rev for a doc not previously in the cache - cache.addToCache(testLogEntry(6, "doc4", "2-a"), true) + cache.addToCache(ctx, testLogEntry(6, "doc4", "2-a"), true) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 3, active) assert.Equal(t, 0, tombstones) assert.Equal(t, 1, removals) // Add a removal rev for a doc previously in the cache - cache.addToCache(testLogEntry(7, "doc1", "3-a"), true) + cache.addToCache(ctx, testLogEntry(7, "doc1", "3-a"), true) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 2, active) assert.Equal(t, 0, tombstones) @@ -554,7 +554,7 @@ func TestChannelCacheStats(t *testing.T) { // Add a new tombstone to the cache tombstone := testLogEntry(8, "doc5", "2-a") tombstone.SetDeleted() - cache.addToCache(tombstone, false) + cache.addToCache(ctx, tombstone, false) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 2, active) assert.Equal(t, 1, tombstones) @@ -563,7 +563,7 @@ func TestChannelCacheStats(t *testing.T) { // Add a tombstone that's also a removal. Should only be tracked as removal tombstone = testLogEntry(9, "doc6", "2-a") tombstone.SetDeleted() - cache.addToCache(tombstone, true) + cache.addToCache(ctx, tombstone, true) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 2, active) assert.Equal(t, 1, tombstones) @@ -572,7 +572,7 @@ func TestChannelCacheStats(t *testing.T) { // Tombstone a document id already present in the cache as an active revision tombstone = testLogEntry(10, "doc2", "3-a") tombstone.SetDeleted() - cache.addToCache(tombstone, false) + cache.addToCache(ctx, tombstone, false) active, tombstones, removals = getCacheUtilization(testStats) assert.Equal(t, 1, active) assert.Equal(t, 2, tombstones) @@ -600,16 +600,16 @@ func TestChannelCacheStatsOnPrune(t *testing.T) { cache.options.ChannelCacheMaxLength = 5 // Add more than ChannelCacheMaxLength entries to cache - cache.addToCache(testLogEntry(1, "doc1", "1-a"), false) - cache.addToCache(testLogEntry(2, "doc2", "1-a"), true) - cache.addToCache(testLogEntry(3, "doc3", "1-a"), false) - cache.addToCache(testLogEntry(4, "doc4", "1-a"), true) - cache.addToCache(testLogEntry(5, "doc5", "1-a"), false) - cache.addToCache(testLogEntry(6, "doc6", "1-a"), true) - cache.addToCache(testLogEntry(7, "doc7", "1-a"), false) - cache.addToCache(testLogEntry(8, "doc8", "1-a"), true) - cache.addToCache(testLogEntry(9, "doc9", "1-a"), false) - cache.addToCache(testLogEntry(10, "doc10", "1-a"), true) + cache.addToCache(ctx, testLogEntry(1, "doc1", "1-a"), false) + cache.addToCache(ctx, testLogEntry(2, "doc2", "1-a"), true) + cache.addToCache(ctx, testLogEntry(3, "doc3", "1-a"), false) + cache.addToCache(ctx, testLogEntry(4, "doc4", "1-a"), true) + cache.addToCache(ctx, testLogEntry(5, "doc5", "1-a"), false) + cache.addToCache(ctx, testLogEntry(6, "doc6", "1-a"), true) + cache.addToCache(ctx, testLogEntry(7, "doc7", "1-a"), false) + cache.addToCache(ctx, testLogEntry(8, "doc8", "1-a"), true) + cache.addToCache(ctx, testLogEntry(9, "doc9", "1-a"), false) + cache.addToCache(ctx, testLogEntry(10, "doc10", "1-a"), true) active, tombstones, removals := getCacheUtilization(testStats) assert.Equal(t, 2, active) @@ -639,18 +639,18 @@ func TestChannelCacheStatsOnPrepend(t *testing.T) { cache.options.ChannelCacheMaxLength = 15 // Add 9 entries to cache, 3 of each type - cache.addToCache(testLogEntry(100, "active1", "2-a"), false) - cache.addToCache(testLogEntry(102, "active2", "2-a"), false) - cache.addToCache(testLogEntry(104, "removal1", "2-a"), true) - cache.addToCache(et(106, "tombstone1", "2-a"), false) - cache.addToCache(testLogEntry(107, "removal2", "2-a"), true) - cache.addToCache(testLogEntry(108, "removal3", "2-a"), true) - cache.addToCache(et(110, "tombstone2", "2-a"), false) - cache.addToCache(et(111, "tombstone3", "2-a"), false) - cache.addToCache(testLogEntry(112, "active3", "2-a"), false) + cache.addToCache(ctx, testLogEntry(100, "active1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(102, "active2", "2-a"), false) + cache.addToCache(ctx, testLogEntry(104, "removal1", "2-a"), true) + cache.addToCache(ctx, et(106, "tombstone1", "2-a"), false) + cache.addToCache(ctx, testLogEntry(107, "removal2", "2-a"), true) + cache.addToCache(ctx, testLogEntry(108, "removal3", "2-a"), true) + cache.addToCache(ctx, et(110, "tombstone2", "2-a"), false) + cache.addToCache(ctx, et(111, "tombstone3", "2-a"), false) + cache.addToCache(ctx, testLogEntry(112, "active3", "2-a"), false) active, tombstones, removals := getCacheUtilization(testStats) - assert.Equal(t, 3, active) + require.Equal(t, 3, active) assert.Equal(t, 3, tombstones) assert.Equal(t, 3, removals) @@ -707,11 +707,11 @@ func TestBypassSingleChannelCache(t *testing.T) { queryHandler: queryHandler, } - entries, err := bypassCache.GetChanges(base.TestCtx(t), getChangesOptionsWithZeroSeq()) + entries, err := bypassCache.GetChanges(base.TestCtx(t), getChangesOptionsWithZeroSeq(t)) assert.NoError(t, err) require.Len(t, entries, 10) - validFrom, cachedEntries := bypassCache.GetCachedChanges(getChangesOptionsWithZeroSeq()) + validFrom, cachedEntries := bypassCache.GetCachedChanges(getChangesOptionsWithZeroSeq(t)) assert.Equal(t, uint64(math.MaxUint64), validFrom) require.Len(t, cachedEntries, 0) } @@ -739,7 +739,7 @@ func BenchmarkChannelCacheUniqueDocs_Ordered(b *testing.B) { } b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], "1-a"), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], "1-a"), false) } } @@ -764,7 +764,7 @@ func BenchmarkChannelCacheRepeatedDocs5(b *testing.B) { docIDs, revStrings := generateDocs(5.0, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) } } @@ -788,7 +788,7 @@ func BenchmarkChannelCacheRepeatedDocs20(b *testing.B) { docIDs, revStrings := generateDocs(20.0, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) } } @@ -812,7 +812,7 @@ func BenchmarkChannelCacheRepeatedDocs50(b *testing.B) { docIDs, revStrings := generateDocs(50.0, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) } } @@ -836,7 +836,7 @@ func BenchmarkChannelCacheRepeatedDocs80(b *testing.B) { docIDs, revStrings := generateDocs(80.0, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) } } @@ -860,7 +860,7 @@ func BenchmarkChannelCacheRepeatedDocs95(b *testing.B) { docIDs, revStrings := generateDocs(95.0, b.N) b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) + cache.addToCache(ctx, testLogEntry(uint64(i), docIDs[i], revStrings[i]), false) } } @@ -896,7 +896,7 @@ func BenchmarkChannelCacheUniqueDocs_Unordered(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - cache.addToCache(docs[i], false) + cache.addToCache(ctx, docs[i], false) } } diff --git a/db/channel_cache_test.go b/db/channel_cache_test.go index 72d61af1b3..4637539848 100644 --- a/db/channel_cache_test.go +++ b/db/channel_cache_test.go @@ -32,22 +32,22 @@ func TestChannelCacheMaxSize(t *testing.T) { collectionID := GetSingleDatabaseCollection(t, db.DatabaseContext).GetCollectionID() // Make channels active - _, err := cache.GetChanges(ctx, channels.NewID("TestA", collectionID), getChangesOptionsWithCtxOnly()) + _, err := cache.GetChanges(ctx, channels.NewID("TestA", collectionID), getChangesOptionsWithCtxOnly(t)) require.NoError(t, err) - _, err = cache.GetChanges(ctx, channels.NewID("TestB", collectionID), getChangesOptionsWithCtxOnly()) + _, err = cache.GetChanges(ctx, channels.NewID("TestB", collectionID), getChangesOptionsWithCtxOnly(t)) require.NoError(t, err) - _, err = cache.GetChanges(ctx, channels.NewID("TestC", collectionID), getChangesOptionsWithCtxOnly()) + _, err = cache.GetChanges(ctx, channels.NewID("TestC", collectionID), getChangesOptionsWithCtxOnly(t)) require.NoError(t, err) - _, err = cache.GetChanges(ctx, channels.NewID("TestD", collectionID), getChangesOptionsWithCtxOnly()) + _, err = cache.GetChanges(ctx, channels.NewID("TestD", collectionID), getChangesOptionsWithCtxOnly(t)) require.NoError(t, err) // Add some entries to caches, leaving some empty caches - cache.AddToCache(logEntry(1, "doc1", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) - cache.AddToCache(logEntry(2, "doc2", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) - cache.AddToCache(logEntry(3, "doc3", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) - cache.AddToCache(logEntry(4, "doc4", "1-a", []string{"TestC"}, collectionID)) + cache.AddToCache(ctx, logEntry(1, "doc1", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) + cache.AddToCache(ctx, logEntry(2, "doc2", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) + cache.AddToCache(ctx, logEntry(3, "doc3", "1-a", []string{"TestB", "TestC", "TestD"}, collectionID)) + cache.AddToCache(ctx, logEntry(4, "doc4", "1-a", []string{"TestC"}, collectionID)) - db.UpdateCalculatedStats() + db.UpdateCalculatedStats(ctx) maxEntries := db.DbStats.Cache().ChannelCacheMaxEntries.Value() assert.Equal(t, 4, int(maxEntries)) @@ -84,22 +84,23 @@ func TestChannelCacheSimpleCompact(t *testing.T) { testStats := dbstats.Cache() activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) + ctx := base.TestCtx(t) cache, err := newChannelCache(base.TestCtx(t), "testDb", options, testQueryHandlerFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) require.NoError(t, err) // Add 16 channels to the cache. Shouldn't trigger compaction (hwm is not exceeded) for i := 1; i <= 16; i++ { channelName := fmt.Sprintf("chan_%d", i) - cache.addChannelCache(channels.NewID(channelName, base.DefaultCollectionID)) + cache.addChannelCache(ctx, channels.NewID(channelName, base.DefaultCollectionID)) } // Validate cache size assert.Equal(t, 16, cache.channelCaches.Length()) // Add another channel to cache - cache.addChannelCache(channels.NewID("chan_17", base.DefaultCollectionID)) + cache.addChannelCache(ctx, channels.NewID("chan_17", base.DefaultCollectionID)) assert.True(t, waitForCompaction(cache), "Compaction didn't complete in expected time") @@ -125,15 +126,17 @@ func TestChannelCacheCompactInactiveChannels(t *testing.T) { testStats := dbstats.Cache() activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) + + ctx := base.TestCtx(t) cache, err := newChannelCache(base.TestCtx(t), "testDb", options, testQueryHandlerFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) // Add 16 channels to the cache. Mark odd channels as active, even channels as inactive. // Shouldn't trigger compaction (hwm is not exceeded) for i := 1; i <= 18; i++ { channel := channels.NewID(fmt.Sprintf("chan_%d", i), base.DefaultCollectionID) - cache.addChannelCache(channel) + cache.addChannelCache(ctx, channel) if i%2 == 1 { log.Printf("Marking channel %q as active", channel) activeChannels.IncrChannel(channel) @@ -144,7 +147,7 @@ func TestChannelCacheCompactInactiveChannels(t *testing.T) { log.Printf("adding 19th element to cache...") // Add another channel to cache, should trigger compaction - cache.addChannelCache(channels.NewID("chan_19", base.DefaultCollectionID)) + cache.addChannelCache(ctx, channels.NewID("chan_19", base.DefaultCollectionID)) assert.True(t, waitForCompaction(cache), "Compaction didn't complete in expected time") @@ -183,15 +186,16 @@ func TestChannelCacheCompactNRU(t *testing.T) { testStats := dbstats.Cache() activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) + ctx := base.TestCtx(t) cache, err := newChannelCache(base.TestCtx(t), "testDb", options, testQueryHandlerFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) // Add 18 channels to the cache. Mark channels 1-10 as active // Shouldn't trigger compaction (hwm is not exceeded) for i := 1; i <= 18; i++ { channel := channels.NewID(fmt.Sprintf("chan_%d", i), base.DefaultCollectionID) - cache.addChannelCache(channel) + cache.addChannelCache(ctx, channel) if i <= 10 { log.Printf("Marking channel %q as active", channel) activeChannels.IncrChannel(channel) @@ -201,7 +205,7 @@ func TestChannelCacheCompactNRU(t *testing.T) { assert.Equal(t, 18, cache.channelCaches.Length()) // Add another channel to cache, should trigger compaction - cache.addChannelCache(channels.NewID("chan_19", base.DefaultCollectionID)) + cache.addChannelCache(ctx, channels.NewID("chan_19", base.DefaultCollectionID)) assert.True(t, waitForCompaction(cache), "Compaction didn't complete in expected time") // Expect channels 1-10, 11-15 to be evicted, and all to be marked as NRU during compaction @@ -223,7 +227,7 @@ func TestChannelCacheCompactNRU(t *testing.T) { channel := channels.NewID(fmt.Sprintf("chan_%d", i), base.DefaultCollectionID) cacheElement, isCached := cache.channelCaches.Get(channel) assert.True(t, isCached, fmt.Sprintf("Expected %s to be cached during recently used update", channel)) - AsSingleChannelCache(cacheElement).recentlyUsed.Set(true) + AsSingleChannelCache(ctx, cacheElement).recentlyUsed.Set(true) } // Add new channels to trigger compaction. At start of compaction, expect: @@ -234,9 +238,9 @@ func TestChannelCacheCompactNRU(t *testing.T) { channel := channels.NewID(fmt.Sprintf("chan_%d", i), base.DefaultCollectionID) if i <= 10 { log.Printf("Marking channel %q as inactive", channel) - activeChannels.DecrChannel(channel) + activeChannels.DecrChannel(ctx, channel) } else { - cache.addChannelCache(channel) + cache.addChannelCache(ctx, channel) } } @@ -280,9 +284,10 @@ func TestChannelCacheHighLoadCacheHit(t *testing.T) { queryHandler := &testQueryHandler{} activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) - cache, err := newChannelCache(base.TestCtx(t), "testDb", options, queryHandler.asFactory, activeChannels, testStats) + ctx := base.TestCtx(t) + cache, err := newChannelCache(ctx, "testDb", options, queryHandler.asFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) channelCount := 90 // define channel set @@ -298,7 +303,7 @@ func TestChannelCacheHighLoadCacheHit(t *testing.T) { // Send entry to the cache. Don't reuse queryEntry here, as AddToCache strips out the channels property logEntry := testLogEntryForChannels(1, channelNames) - cache.AddToCache(logEntry) + cache.AddToCache(ctx, logEntry) workerCount := 25 getChangesCount := 400 @@ -312,7 +317,7 @@ func TestChannelCacheHighLoadCacheHit(t *testing.T) { for i := 0; i < getChangesCount; i++ { channelNumber := rand.Intn(channelCount) + 1 channel := channels.NewID(fmt.Sprintf("chan_%d", channelNumber), base.DefaultCollectionID) - options := getChangesOptionsWithCtxOnly() + options := getChangesOptionsWithCtxOnly(t) changes, err := cache.GetChanges(base.TestCtx(t), channel, options) if len(changes) == 1 { changesSuccessCount++ @@ -354,9 +359,10 @@ func TestChannelCacheHighLoadCacheMiss(t *testing.T) { queryHandler := &testQueryHandler{} activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) - cache, err := newChannelCache(base.TestCtx(t), "testDb", options, queryHandler.asFactory, activeChannels, testStats) + ctx := base.TestCtx(t) + cache, err := newChannelCache(ctx, "testDb", options, queryHandler.asFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) channelCount := 200 // define channel set @@ -372,7 +378,7 @@ func TestChannelCacheHighLoadCacheMiss(t *testing.T) { // Send entry to the cache. Don't reuse queryEntry here, as AddToCache strips out the channels property logEntry := testLogEntryForChannels(1, channelNames) - cache.AddToCache(logEntry) + cache.AddToCache(ctx, logEntry) workerCount := 25 getChangesCount := 400 @@ -386,7 +392,7 @@ func TestChannelCacheHighLoadCacheMiss(t *testing.T) { for i := 0; i < getChangesCount; i++ { channelNumber := rand.Intn(channelCount) + 1 channel := channels.NewID(fmt.Sprintf("chan_%d", channelNumber), base.DefaultCollectionID) - options := getChangesOptionsWithCtxOnly() + options := getChangesOptionsWithCtxOnly(t) changes, err := cache.GetChanges(base.TestCtx(t), channel, options) if len(changes) == 1 { changesSuccessCount++ @@ -423,9 +429,10 @@ func TestChannelCacheBypass(t *testing.T) { queryHandler := &testQueryHandler{} activeChannelStat := &base.SgwIntStat{} activeChannels := channels.NewActiveChannels(activeChannelStat) - cache, err := newChannelCache(base.TestCtx(t), "testDb", options, queryHandler.asFactory, activeChannels, testStats) + ctx := base.TestCtx(t) + cache, err := newChannelCache(ctx, "testDb", options, queryHandler.asFactory, activeChannels, testStats) require.NoError(t, err, "Background task error whilst creating channel cache") - defer cache.Stop() + defer cache.Stop(ctx) channelCount := 100 // define channel set @@ -441,12 +448,12 @@ func TestChannelCacheBypass(t *testing.T) { // Send entry to the cache. Don't reuse queryEntry here, as AddToCache strips out the channels property logEntry := testLogEntryForChannels(1, channelNames) - cache.AddToCache(logEntry) + cache.AddToCache(ctx, logEntry) // Issue queries for all channels. First 20 should end up in the cache, remaining 80 should trigger bypass for c := 1; c <= channelCount; c++ { channel := channels.NewID(fmt.Sprintf("chan_%d", c), base.DefaultCollectionID) - options := getChangesOptionsWithCtxOnly() + options := getChangesOptionsWithCtxOnly(t) changes, err := cache.GetChanges(base.TestCtx(t), channel, options) assert.NoError(t, err, fmt.Sprintf("Error getting changes for channel %q", channel)) assert.True(t, len(changes) == 1, "Expected one change per channel") diff --git a/db/crud.go b/db/crud.go index 296acb8c87..f3484ff15d 100644 --- a/db/crud.go +++ b/db/crud.go @@ -60,7 +60,7 @@ func (c *DatabaseCollection) GetDocumentWithRaw(ctx context.Context, docid strin return nil, nil, base.HTTPErrorf(400, "Invalid doc ID") } if c.UseXattrs() { - doc, rawBucketDoc, err = c.GetDocWithXattr(key, unmarshalLevel) + doc, rawBucketDoc, err = c.GetDocWithXattr(ctx, key, unmarshalLevel) if err != nil { return nil, nil, err } @@ -98,7 +98,7 @@ func (c *DatabaseCollection) GetDocumentWithRaw(ctx context.Context, docid strin if !doc.HasValidSyncData() { // Check whether doc has been upgraded to use xattrs - upgradeDoc, _ := c.checkForUpgrade(docid, unmarshalLevel) + upgradeDoc, _ := c.checkForUpgrade(ctx, docid, unmarshalLevel) if upgradeDoc == nil { return nil, nil, base.HTTPErrorf(404, "Not imported") } @@ -114,7 +114,7 @@ func (c *DatabaseCollection) GetDocumentWithRaw(ctx context.Context, docid strin return doc, rawBucketDoc, nil } -func (c *DatabaseCollection) GetDocWithXattr(key string, unmarshalLevel DocumentUnmarshalLevel) (doc *Document, rawBucketDoc *sgbucket.BucketDocument, err error) { +func (c *DatabaseCollection) GetDocWithXattr(ctx context.Context, key string, unmarshalLevel DocumentUnmarshalLevel) (doc *Document, rawBucketDoc *sgbucket.BucketDocument, err error) { rawBucketDoc = &sgbucket.BucketDocument{} var getErr error rawBucketDoc.Cas, getErr = c.dataStore.GetWithXattr(key, base.SyncXattrName, c.userXattrKey(), &rawBucketDoc.Body, &rawBucketDoc.Xattr, &rawBucketDoc.UserXattr) @@ -123,7 +123,7 @@ func (c *DatabaseCollection) GetDocWithXattr(key string, unmarshalLevel Document } var unmarshalErr error - doc, unmarshalErr = unmarshalDocumentWithXattr(key, rawBucketDoc.Body, rawBucketDoc.Xattr, rawBucketDoc.UserXattr, rawBucketDoc.Cas, unmarshalLevel) + doc, unmarshalErr = unmarshalDocumentWithXattr(ctx, key, rawBucketDoc.Body, rawBucketDoc.Xattr, rawBucketDoc.UserXattr, rawBucketDoc.Cas, unmarshalLevel) if unmarshalErr != nil { return nil, nil, unmarshalErr } @@ -150,7 +150,7 @@ func (c *DatabaseCollection) GetDocSyncData(ctx context.Context, docid string) ( } // Unmarshal xattr only - doc, unmarshalErr := unmarshalDocumentWithXattr(docid, nil, rawXattr, rawUserXattr, cas, DocUnmarshalSync) + doc, unmarshalErr := unmarshalDocumentWithXattr(ctx, docid, nil, rawXattr, rawUserXattr, cas, DocUnmarshalSync) if unmarshalErr != nil { return emptySyncData, unmarshalErr } @@ -200,7 +200,7 @@ func (db *DatabaseCollection) GetDocSyncDataNoImport(ctx context.Context, docid var xattrValue []byte if cas, err = db.dataStore.GetXattr(docid, base.SyncXattrName, &xattrValue); err == nil { var doc *Document - doc, err = unmarshalDocumentWithXattr(docid, nil, xattrValue, nil, cas, level) + doc, err = unmarshalDocumentWithXattr(ctx, docid, nil, xattrValue, nil, cas, level) if err == nil { syncData = doc.SyncData } @@ -218,7 +218,7 @@ func (db *DatabaseCollection) GetDocSyncDataNoImport(ctx context.Context, docid // (unmarshaling populates `syncData` since `docRoot` points to it.) if !syncData.HasValidSyncData() { base.InfofCtx(ctx, base.KeyCRUD, "No valid sync data in doc %q; checking for xattrs", base.UD(docid)) - if upgradeDoc, _ := db.checkForUpgrade(docid, level); upgradeDoc != nil { + if upgradeDoc, _ := db.checkForUpgrade(ctx, docid, level); upgradeDoc != nil { // No valid sync data in doc, but doc has been upgraded to use xattrs syncData = upgradeDoc.SyncData } else { @@ -287,10 +287,10 @@ func (db *DatabaseCollectionWithUser) Get1xRevBodyWithHistory(ctx context.Contex requestedHistory = nil } if requestedHistory != nil { - _, requestedHistory = trimEncodedRevisionsToAncestor(requestedHistory, historyFrom, maxHistory) + _, requestedHistory = trimEncodedRevisionsToAncestor(ctx, requestedHistory, historyFrom, maxHistory) } - return rev.Mutable1xBody(db, requestedHistory, attachmentsSince, showExp) + return rev.Mutable1xBody(ctx, db, requestedHistory, attachmentsSince, showExp) } // Underlying revision retrieval used by Get1xRevBody, Get1xRevBodyWithHistory, GetRevCopy. @@ -334,7 +334,7 @@ func (db *DatabaseCollectionWithUser) getRev(ctx context.Context, docid, revid s requestedHistory = nil } if requestedHistory != nil { - _, requestedHistory = trimEncodedRevisionsToAncestor(requestedHistory, historyFrom, maxHistory) + _, requestedHistory = trimEncodedRevisionsToAncestor(ctx, requestedHistory, historyFrom, maxHistory) } isAuthorized, redactedRev := db.authorizeUserForChannels(docid, revision.RevID, revision.Channels, revision.Deleted, requestedHistory) @@ -393,7 +393,7 @@ func (db *DatabaseCollectionWithUser) GetDelta(ctx context.Context, docID, fromR if fromRevision.Delta != nil { if fromRevision.Delta.ToRevID == toRevID { - isAuthorized, redactedBody := db.authorizeUserForChannels(docID, toRevID, fromRevision.Delta.ToChannels, fromRevision.Delta.ToDeleted, encodeRevisions(docID, fromRevision.Delta.RevisionHistory)) + isAuthorized, redactedBody := db.authorizeUserForChannels(docID, toRevID, fromRevision.Delta.ToChannels, fromRevision.Delta.ToDeleted, encodeRevisions(ctx, docID, fromRevision.Delta.RevisionHistory)) if !isAuthorized { return nil, &redactedBody, nil } @@ -728,7 +728,7 @@ func (db *DatabaseCollectionWithUser) get1xRevFromDoc(ctx context.Context, doc * if getHistoryErr != nil { return nil, removed, getHistoryErr } - kvPairs = append(kvPairs, base.KVPair{Key: BodyRevisions, Val: encodeRevisions(doc.ID, validatedHistory)}) + kvPairs = append(kvPairs, base.KVPair{Key: BodyRevisions, Val: encodeRevisions(ctx, doc.ID, validatedHistory)}) } bodyBytes, err = base.InjectJSONProperties(bodyBytes, kvPairs...) @@ -790,7 +790,7 @@ func (db *DatabaseCollectionWithUser) getAvailableRevAttachments(ctx context.Con // Moves a revision's ancestor's body out of the document object and into a separate db doc. func (db *DatabaseCollectionWithUser) backupAncestorRevs(ctx context.Context, doc *Document, newDoc *Document) { - newBodyBytes, err := newDoc.BodyBytes() + newBodyBytes, err := newDoc.BodyBytes(ctx) if err != nil { base.WarnfCtx(ctx, "Error getting body bytes when backing up ancestor revs") return @@ -816,7 +816,7 @@ func (db *DatabaseCollectionWithUser) backupAncestorRevs(ctx context.Context, do if ancestorRevId == doc.CurrentRev { doc.RemoveBody() } else { - doc.removeRevisionBody(ancestorRevId) + doc.removeRevisionBody(ctx, ancestorRevId) } } @@ -826,7 +826,7 @@ func (db *DatabaseCollectionWithUser) OnDemandImportForWrite(ctx context.Context // Check whether the doc requiring import is an SDK delete isDelete := false - if doc.Body() == nil { + if doc.Body(ctx) == nil { isDelete = true } else { isDelete = deleted @@ -855,7 +855,7 @@ func (db *DatabaseCollectionWithUser) Put(ctx context.Context, docid string, bod // Get the revision ID to match, and the new generation number: matchRev, _ := body[BodyRev].(string) - generation, _ := ParseRevID(matchRev) + generation, _ := ParseRevID(ctx, matchRev) if generation < 0 { return "", nil, base.HTTPErrorf(http.StatusBadRequest, "Invalid revision ID") } @@ -922,7 +922,7 @@ func (db *DatabaseCollectionWithUser) Put(ctx context.Context, docid string, bod if !doc.History[matchRev].Deleted { conflictErr = base.HTTPErrorf(http.StatusConflict, "Document exists") } else { - generation, _ = ParseRevID(matchRev) + generation, _ = ParseRevID(ctx, matchRev) generation++ } } @@ -1006,7 +1006,7 @@ func (db *DatabaseCollectionWithUser) PutExistingRev(ctx context.Context, newDoc // 3. If noConflicts == true and a conflictResolverFunc is provided, conflicts will be resolved and the result added to the document. func (db *DatabaseCollectionWithUser) PutExistingRevWithConflictResolution(ctx context.Context, newDoc *Document, docHistory []string, noConflicts bool, conflictResolver *ConflictResolver, forceAllowConflictingTombstone bool, existingDoc *sgbucket.BucketDocument) (doc *Document, newRevID string, err error) { newRev := docHistory[0] - generation, _ := ParseRevID(newRev) + generation, _ := ParseRevID(ctx, newRev) if generation < 0 { return nil, "", base.HTTPErrorf(http.StatusBadRequest, "Invalid revision ID") } @@ -1192,7 +1192,7 @@ func (db *DatabaseCollectionWithUser) resolveConflict(ctx context.Context, local RemoteDocument: remoteDocBody, } - resolvedBody, resolutionType, resolveFuncError := resolver.Resolve(conflict) + resolvedBody, resolutionType, resolveFuncError := resolver.Resolve(ctx, conflict) if resolveFuncError != nil { base.InfofCtx(ctx, base.KeyReplicate, "Error when running conflict resolution for doc %s: %v", base.UD(localDoc.ID), resolveFuncError) return "", nil, resolveFuncError @@ -1243,13 +1243,13 @@ func (db *DatabaseCollectionWithUser) resolveDocRemoteWins(ctx context.Context, func (db *DatabaseCollectionWithUser) resolveDocLocalWins(ctx context.Context, localDoc *Document, remoteDoc *Document, conflict Conflict, docHistory []string) (resolvedRevID string, updatedHistory []string, err error) { // Clone the local revision as a child of the remote revision - docBodyBytes, err := localDoc.BodyBytes() + docBodyBytes, err := localDoc.BodyBytes(ctx) if err != nil { return "", nil, fmt.Errorf("Unable to retrieve local document body while resolving conflict: %w", err) } remoteRevID := remoteDoc.RevID - remoteGeneration, _ := ParseRevID(remoteRevID) + remoteGeneration, _ := ParseRevID(ctx, remoteRevID) var newRevID string if !localDoc.Deleted { @@ -1260,7 +1260,7 @@ func (db *DatabaseCollectionWithUser) resolveDocLocalWins(ctx context.Context, l // and need to ensure the remote branch is the winning branch. To do that, we inject entries into the remote // branch's history until it's generation is longer than the local branch. remoteDoc.Deleted = localDoc.Deleted - localGeneration, _ := ParseRevID(localDoc.CurrentRev) + localGeneration, _ := ParseRevID(ctx, localDoc.CurrentRev) requiredAdditionalRevs := localGeneration - remoteGeneration injectedRevBody := []byte("{}") @@ -1291,9 +1291,9 @@ func (db *DatabaseCollectionWithUser) resolveDocLocalWins(ctx context.Context, l commonAncestorRevID := localDoc.SyncData.History.findAncestorFromSet(localDoc.CurrentRev, docHistory) commonAncestorGen := 0 if commonAncestorRevID != "" { - commonAncestorGen, _ = ParseRevID(commonAncestorRevID) + commonAncestorGen, _ = ParseRevID(ctx, commonAncestorRevID) } - newRevIDGen, _ := ParseRevID(newRevID) + newRevIDGen, _ := ParseRevID(ctx, newRevID) // If attachment revpos is older than common ancestor, or common ancestor doesn't exist, set attachment's // revpos to the generation of newRevID (i.e. treat as previously unknown to this revtree branch) @@ -1347,7 +1347,7 @@ func (db *DatabaseCollectionWithUser) resolveDocMerge(ctx context.Context, local } remoteRevID := remoteDoc.RevID - remoteGeneration, _ := ParseRevID(remoteRevID) + remoteGeneration, _ := ParseRevID(ctx, remoteRevID) mergedRevID, err := CreateRevID(remoteGeneration+1, remoteRevID, mergedBody) if err != nil { return "", nil, err @@ -1379,7 +1379,7 @@ func (db *DatabaseCollectionWithUser) tombstoneActiveRevision(ctx context.Contex } // Create tombstone - newGeneration := genOfRevID(revID) + 1 + newGeneration := genOfRevID(ctx, revID) + 1 newRevID := CreateRevIDWithBytes(newGeneration, revID, []byte(DeletedDocument)) err = doc.History.addRevision(doc.ID, RevInfo{ @@ -1392,7 +1392,7 @@ func (db *DatabaseCollectionWithUser) tombstoneActiveRevision(ctx context.Contex } // Backup previous revision body, then remove the current body from the doc - bodyBytes, err := doc.BodyBytes() + bodyBytes, err := doc.BodyBytes(ctx) if err == nil { _ = db.setOldRevisionJSON(ctx, doc.ID, revID, bodyBytes, db.oldRevExpirySeconds()) } @@ -1401,9 +1401,9 @@ func (db *DatabaseCollectionWithUser) tombstoneActiveRevision(ctx context.Contex return newRevID, nil } -func (doc *Document) updateWinningRevAndSetDocFlags() { +func (doc *Document) updateWinningRevAndSetDocFlags(ctx context.Context) { var branched, inConflict bool - doc.CurrentRev, branched, inConflict = doc.History.winningRevision() + doc.CurrentRev, branched, inConflict = doc.History.winningRevision(ctx) doc.setFlag(channels.Deleted, doc.History[doc.CurrentRev].Deleted) doc.setFlag(channels.Conflict, inConflict) doc.setFlag(channels.Branched, branched) @@ -1417,7 +1417,7 @@ func (doc *Document) updateWinningRevAndSetDocFlags() { func (db *DatabaseCollectionWithUser) storeOldBodyInRevTreeAndUpdateCurrent(ctx context.Context, doc *Document, prevCurrentRev string, newRevID string, newDoc *Document, newDocHasAttachments bool) { if doc.HasBody() && doc.CurrentRev != prevCurrentRev && prevCurrentRev != "" { // Store the doc's previous body into the revision tree: - oldBodyJson, marshalErr := doc.BodyBytes() + oldBodyJson, marshalErr := doc.BodyBytes(ctx) if marshalErr != nil { base.WarnfCtx(ctx, "Unable to marshal document body for storage in rev tree: %v", marshalErr) } @@ -1428,7 +1428,7 @@ func (db *DatabaseCollectionWithUser) storeOldBodyInRevTreeAndUpdateCurrent(ctx // Stamp _attachments into the old body we're about to backup // We need to do a revpos check here because doc actually contains the new attachments if len(doc.SyncData.Attachments) > 0 { - prevCurrentRevGen, _ := ParseRevID(prevCurrentRev) + prevCurrentRevGen, _ := ParseRevID(ctx, prevCurrentRev) bodyAtts := make(AttachmentsMeta) for attName, attMeta := range doc.SyncData.Attachments { if attMetaMap, ok := attMeta.(map[string]interface{}); ok { @@ -1465,7 +1465,7 @@ func (db *DatabaseCollectionWithUser) storeOldBodyInRevTreeAndUpdateCurrent(ctx doc.setNonWinningRevisionBody(prevCurrentRev, oldBodyJson, db.AllowExternalRevBodyStorage(), oldDocHasAttachments) } // Store the new revision body into the doc: - doc.setRevisionBody(newRevID, newDoc, db.AllowExternalRevBodyStorage(), newDocHasAttachments) + doc.setRevisionBody(ctx, newRevID, newDoc, db.AllowExternalRevBodyStorage(), newDocHasAttachments) doc.SyncData.Attachments = newDoc.DocAttachments if doc.CurrentRev == newRevID { @@ -1475,7 +1475,7 @@ func (db *DatabaseCollectionWithUser) storeOldBodyInRevTreeAndUpdateCurrent(ctx doc.NewestRev = newRevID doc.setFlag(channels.Hidden, true) if doc.CurrentRev != prevCurrentRev { - doc.promoteNonWinningRevisionBody(doc.CurrentRev, db.RevisionBodyLoader) + doc.promoteNonWinningRevisionBody(ctx, doc.CurrentRev, db.RevisionBodyLoader) } } } @@ -1589,7 +1589,7 @@ func (db *DatabaseContext) assignSequence(ctx context.Context, docSequence uint6 for { var err error - if docSequence, err = db.sequences.nextSequence(); err != nil { + if docSequence, err = db.sequences.nextSequence(ctx); err != nil { return unusedSequences, err } @@ -1736,7 +1736,7 @@ func (col *DatabaseCollectionWithUser) documentUpdateFunc(ctx context.Context, d } prevCurrentRev := doc.CurrentRev - doc.updateWinningRevAndSetDocFlags() + doc.updateWinningRevAndSetDocFlags(ctx) newDocHasAttachments := len(newAttachments) > 0 col.storeOldBodyInRevTreeAndUpdateCurrent(ctx, doc, prevCurrentRev, newRevID, newDoc, newDocHasAttachments) @@ -1781,8 +1781,8 @@ func (col *DatabaseCollectionWithUser) documentUpdateFunc(ctx context.Context, d if err != nil { return } - changedAccessPrincipals = doc.Access.updateAccess(doc, access) - changedRoleAccessUsers = doc.RoleAccess.updateAccess(doc, roles) + changedAccessPrincipals = doc.Access.updateAccess(ctx, doc, access) + changedRoleAccessUsers = doc.RoleAccess.updateAccess(ctx, doc, roles) } else { base.DebugfCtx(ctx, base.KeyCRUD, "updateDoc(%q): Rev %q leaves %q still current", @@ -1790,7 +1790,7 @@ func (col *DatabaseCollectionWithUser) documentUpdateFunc(ctx context.Context, d } // Prune old revision history to limit the number of revisions: - if pruned := doc.pruneRevisions(col.revsLimit(), doc.CurrentRev); pruned > 0 { + if pruned := doc.pruneRevisions(ctx, col.revsLimit(), doc.CurrentRev); pruned > 0 { base.DebugfCtx(ctx, base.KeyCRUD, "updateDoc(%q): Pruned %d old revisions", base.UD(doc.ID), pruned) } @@ -1865,7 +1865,7 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do // If we can't find sync metadata in the document body, check for upgrade. If upgrade, retry write using WriteUpdateWithXattr if err != nil && err.Error() == "409 Not imported" { - _, bucketDocument := db.checkForUpgrade(key, DocUnmarshalAll) + _, bucketDocument := db.checkForUpgrade(ctx, key, DocUnmarshalAll) if bucketDocument != nil && bucketDocument.Xattr != nil { existingDoc = bucketDocument upgradeInProgress = true @@ -1878,7 +1878,7 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do // Update the document, storing metadata in extended attribute casOut, err = db.dataStore.WriteUpdateWithXattr(key, base.SyncXattrName, db.userXattrKey(), expiry, opts, existingDoc, func(currentValue []byte, currentXattr []byte, currentUserXattr []byte, cas uint64) (raw []byte, rawXattr []byte, deleteDoc bool, syncFuncExpiry *uint32, err error) { // Be careful: this block can be invoked multiple times if there are races! - if doc, err = unmarshalDocumentWithXattr(docid, currentValue, currentXattr, currentUserXattr, cas, DocUnmarshalAll); err != nil { + if doc, err = unmarshalDocumentWithXattr(ctx, docid, currentValue, currentXattr, currentUserXattr, cas, DocUnmarshalAll); err != nil { return } prevCurrentRev = doc.CurrentRev @@ -1987,7 +1987,7 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do } // Lazily marshal bytes for storage in revcache - storedDocBytes, err := storedDoc.BodyBytes() + storedDocBytes, err := storedDoc.BodyBytes(ctx) if err != nil { return nil, "", err } @@ -1997,12 +1997,12 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do DocID: docid, RevID: newRevID, BodyBytes: storedDocBytes, - History: encodeRevisions(docid, history), + History: encodeRevisions(ctx, docid, history), Channels: revChannels, Attachments: doc.Attachments, Expiry: doc.Expiry, Deleted: doc.History[newRevID].Deleted, - _shallowCopyBody: storedDoc.Body(), + _shallowCopyBody: storedDoc.Body(ctx), } if createNewRevIDSkipped { @@ -2012,12 +2012,12 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do } if db.eventMgr().HasHandlerForEvent(DocumentChange) { - webhookJSON, err := doc.BodyWithSpecialProperties() + webhookJSON, err := doc.BodyWithSpecialProperties(ctx) if err != nil { base.WarnfCtx(ctx, "Error marshalling doc with id %s and revid %s for webhook post: %v", base.UD(docid), base.UD(newRevID), err) } else { winningRevChange := prevCurrentRev != doc.CurrentRev - err = db.eventMgr().RaiseDocumentChangeEvent(webhookJSON, docid, oldBodyJSON, revChannels, winningRevChange) + err = db.eventMgr().RaiseDocumentChangeEvent(ctx, webhookJSON, docid, oldBodyJSON, revChannels, winningRevChange) if err != nil { base.DebugfCtx(ctx, base.KeyCRUD, "Error raising document change event: %v", err) } @@ -2058,7 +2058,7 @@ func (db *DatabaseCollectionWithUser) updateAndReturnDoc(ctx context.Context, do } // Remove any obsolete non-winning revision bodies - doc.deleteRemovedRevisionBodies(db.dataStore) + doc.deleteRemovedRevisionBodies(ctx, db.dataStore) // Mark affected users/roles as needing to recompute their channel access: db.MarkPrincipalsChanged(ctx, docid, newRevID, changedAccessPrincipals, changedRoleAccessUsers, doc.Sequence) @@ -2306,7 +2306,7 @@ func (col *DatabaseCollectionWithUser) getChannelsAndAccess(ctx context.Context, col.dbStats().Security().NumAccessErrors.Add(1) col.collectionStats.SyncFunctionRejectAccessCount.Add(1) } - } else if !validateAccessMap(access) || !validateRoleAccessMap(roles) { + } else if !validateAccessMap(ctx, access) || !validateRoleAccessMap(ctx, roles) { err = base.HTTPErrorf(500, "Error in JS sync function") } @@ -2352,25 +2352,25 @@ func MakeUserCtx(user auth.User, scopeName string, collectionName string) map[st } // Are the principal and role names in an AccessMap all valid? -func validateAccessMap(access channels.AccessMap) bool { +func validateAccessMap(ctx context.Context, access channels.AccessMap) bool { for name := range access { principalName, _ := channels.AccessNameToPrincipalName(name) if !auth.IsValidPrincipalName(principalName) { - base.WarnfCtx(context.Background(), "Invalid principal name %q in access() or role() call", base.UD(principalName)) + base.WarnfCtx(ctx, "Invalid principal name %q in access() or role() call", base.UD(principalName)) return false } } return true } -func validateRoleAccessMap(roleAccess channels.AccessMap) bool { - if !validateAccessMap(roleAccess) { +func validateRoleAccessMap(ctx context.Context, roleAccess channels.AccessMap) bool { + if !validateAccessMap(ctx, roleAccess) { return false } for _, roles := range roleAccess { for rolename := range roles { if !auth.IsValidPrincipalName(rolename) { - base.WarnfCtx(context.Background(), "Invalid role name %q in role() call", base.UD(rolename)) + base.WarnfCtx(ctx, "Invalid role name %q in role() call", base.UD(rolename)) return false } } @@ -2453,13 +2453,13 @@ func (c *DatabaseCollection) ComputeRolesForUser(ctx context.Context, user auth. } // Checks whether a document has a mobile xattr. Used when running in non-xattr mode to support no downtime upgrade. -func (c *DatabaseCollection) checkForUpgrade(key string, unmarshalLevel DocumentUnmarshalLevel) (*Document, *sgbucket.BucketDocument) { +func (c *DatabaseCollection) checkForUpgrade(ctx context.Context, key string, unmarshalLevel DocumentUnmarshalLevel) (*Document, *sgbucket.BucketDocument) { // If we are using xattrs or Couchbase Server doesn't support them, an upgrade isn't going to be in progress if c.UseXattrs() || !c.dataStore.IsSupported(sgbucket.BucketStoreFeatureXattrs) { return nil, nil } - doc, rawDocument, err := c.GetDocWithXattr(key, unmarshalLevel) + doc, rawDocument, err := c.GetDocWithXattr(ctx, key, unmarshalLevel) if err != nil || doc == nil || !doc.HasValidSyncData() { return nil, nil } @@ -2490,10 +2490,10 @@ func (db *DatabaseCollectionWithUser) RevDiff(ctx context.Context, docid string, if !doc.History.contains(revid) { missing = append(missing, revid) // Look at the doc's leaves for a known possible ancestor: - if gen, _ := ParseRevID(revid); gen > 1 { + if gen, _ := ParseRevID(ctx, revid); gen > 1 { doc.History.forEachLeaf(func(possible *RevInfo) { if !revidsSet.Contains(possible.ID) { - possibleGen, _ := ParseRevID(possible.ID) + possibleGen, _ := ParseRevID(ctx, possible.ID) if possibleGen < gen && possibleGen >= gen-100 { possibleSet[possible.ID] = true } else if possibleGen == gen && possible.Parent != "" { diff --git a/db/crud_test.go b/db/crud_test.go index af57b55ee5..57d338611e 100644 --- a/db/crud_test.go +++ b/db/crud_test.go @@ -238,7 +238,7 @@ func TestHasAttachmentsFlagForLegacyAttachments(t *testing.T) { require.NoError(t, err) // Get the existing bucket doc - _, existingBucketDoc, err := collection.GetDocWithXattr(docID, DocUnmarshalAll) + _, existingBucketDoc, err := collection.GetDocWithXattr(ctx, docID, DocUnmarshalAll) require.NoError(t, err) // Migrate document metadata from document body to system xattr. @@ -1159,7 +1159,7 @@ func BenchmarkHandleRevDelta(b *testing.B) { } deltaSrcMap := map[string]interface{}(deltaSrcBody) - _ = base.Patch(&deltaSrcMap, newDoc.Body()) + _ = base.Patch(&deltaSrcMap, newDoc.Body(ctx)) } b.Run("SmallDiff", func(b *testing.B) { diff --git a/db/database.go b/db/database.go index 3102cfa51d..3b61b8ea64 100644 --- a/db/database.go +++ b/db/database.go @@ -345,7 +345,7 @@ func ConnectToBucket(ctx context.Context, spec base.BucketSpec, failFast bool) ( } description := fmt.Sprintf("Attempt to connect to bucket : %v", spec.BucketName) - err, ibucket := base.RetryLoop(description, worker, getNewDatabaseSleeperFunc()) + err, ibucket := base.RetryLoop(ctx, description, worker, getNewDatabaseSleeperFunc()) if err != nil { return nil, err } @@ -361,11 +361,11 @@ func getServerUUID(ctx context.Context, bucket base.Bucket) (string, error) { } // start a retry loop to get server ID worker := func() (bool, error, interface{}) { - uuid, err := base.GetServerUUID(gocbV2Bucket) + uuid, err := base.GetServerUUID(ctx, gocbV2Bucket) return err != nil, err, uuid } - err, uuid := base.RetryLoopCtx("Getting ServerUUID", worker, getNewDatabaseSleeperFunc(), ctx) + err, uuid := base.RetryLoop(ctx, "Getting ServerUUID", worker, getNewDatabaseSleeperFunc()) return uuid.(string), err } @@ -393,7 +393,7 @@ func NewDatabaseContext(ctx context.Context, dbName string, bucket base.Bucket, // in order to pass it to RegisterImportPindexImpl ctx = base.DatabaseLogCtx(ctx, dbName, options.LoggingConfig.Console) - if err := base.RequireNoBucketTTL(bucket); err != nil { + if err := base.RequireNoBucketTTL(ctx, bucket); err != nil { return nil, err } @@ -476,7 +476,7 @@ func NewDatabaseContext(ctx context.Context, dbName string, bucket base.Bucket, // Initialize sg cluster config. Required even if import and sgreplicate are disabled // on this node, to support replication REST API calls if base.IsEnterpriseEdition() { - sgCfg, err := base.NewCfgSG(metadataStore, metaKeys.SGCfgPrefix(dbContext.Options.GroupID)) + sgCfg, err := base.NewCfgSG(ctx, metadataStore, metaKeys.SGCfgPrefix(dbContext.Options.GroupID)) if err != nil { return nil, err } @@ -600,12 +600,12 @@ func (context *DatabaseContext) Close(ctx context.Context) { waitForBGTCompletion(ctx, BGTCompletionMaxWait, context.backgroundTasks, context.Name) context.sequences.Stop(ctx) context.mutationListener.Stop(ctx) - context.changeCache.Stop() + context.changeCache.Stop(ctx) // Stop the channel cache and it's background tasks. - context.channelCache.Stop() + context.channelCache.Stop(ctx) context.ImportListener.Stop() if context.Heartbeater != nil { - context.Heartbeater.Stop() + context.Heartbeater.Stop(ctx) } if context.SGReplicateMgr != nil { context.SGReplicateMgr.Stop() @@ -733,13 +733,13 @@ func (context *DatabaseContext) RestartListener(ctx context.Context) error { } // Removes previous versions of Sync Gateway's design docs found on the server -func (dbCtx *DatabaseContext) RemoveObsoleteDesignDocs(previewOnly bool) (removedDesignDocs []string, err error) { +func (dbCtx *DatabaseContext) RemoveObsoleteDesignDocs(ctx context.Context, previewOnly bool) (removedDesignDocs []string, err error) { ds := dbCtx.Bucket.DefaultDataStore() viewStore, ok := ds.(sgbucket.ViewStore) if !ok { return []string{}, fmt.Errorf("Datastore does not support views") } - return removeObsoleteDesignDocs(context.TODO(), viewStore, previewOnly, dbCtx.UseViews()) + return removeObsoleteDesignDocs(ctx, viewStore, previewOnly, dbCtx.UseViews()) } // getDataStores returns all datastores on the database, including metadatastore @@ -801,8 +801,8 @@ func (dbCtx *DatabaseContext) RemoveObsoleteIndexes(ctx context.Context, preview // TODO: The underlying code (NotifyCheckForTermination) doesn't actually leverage the specific username - should be refactored // // to remove -func (context *DatabaseContext) NotifyTerminatedChanges(username string) { - context.mutationListener.NotifyCheckForTermination(base.SetOf(base.UserPrefixRoot + username)) +func (context *DatabaseContext) NotifyTerminatedChanges(ctx context.Context, username string) { + context.mutationListener.NotifyCheckForTermination(ctx, base.SetOf(base.UserPrefixRoot+username)) } func (dc *DatabaseContext) TakeDbOffline(ctx context.Context, reason string) error { @@ -815,12 +815,12 @@ func (dc *DatabaseContext) TakeDbOffline(ctx context.Context, reason string) err dc.AccessLock.Lock() defer dc.AccessLock.Unlock() - dc.changeCache.Stop() + dc.changeCache.Stop(ctx) // set DB state to Offline atomic.StoreUint32(&dc.State, DBOffline) - if err := dc.EventMgr.RaiseDBStateChangeEvent(dc.Name, "offline", reason, dc.Options.AdminInterface); err != nil { + if err := dc.EventMgr.RaiseDBStateChangeEvent(ctx, dc.Name, "offline", reason, dc.Options.AdminInterface); err != nil { base.DebugfCtx(ctx, base.KeyCRUD, "Error raising database state change event: %v", err) } @@ -1498,7 +1498,7 @@ func (db *Database) Compact(ctx context.Context, skipRunningStateCheck bool, cal count := len(purgedDocs) purgedDocCount += count if count > 0 { - collection.RemoveFromChangeCache(purgedDocs, startTime) + collection.RemoveFromChangeCache(ctx, purgedDocs, startTime) collection.dbStats().Database().NumTombstonesCompacted.Add(int64(count)) } base.DebugfCtx(ctx, base.KeyAll, "Compacted %v tombstones", count) @@ -1531,7 +1531,7 @@ func (db *DatabaseContext) GetMetadataPurgeInterval(ctx context.Context) time.Du if !ok { return DefaultPurgeInterval } - serverPurgeInterval, err := cbStore.MetadataPurgeInterval() + serverPurgeInterval, err := cbStore.MetadataPurgeInterval(ctx) if err != nil { base.WarnfCtx(ctx, "Unable to retrieve server's metadata purge interval - using default purge interval %.2f days. %s", DefaultPurgeInterval.Hours()/24, err) } @@ -1673,7 +1673,7 @@ func (c *DatabaseCollection) updateAllPrincipalsSequences(ctx context.Context) e if err != nil { return err } - err = c.regeneratePrincipalSequences(authr, role) + err = c.regeneratePrincipalSequences(ctx, authr, role) if err != nil { return err } @@ -1684,7 +1684,7 @@ func (c *DatabaseCollection) updateAllPrincipalsSequences(ctx context.Context) e if err != nil { return err } - err = c.regeneratePrincipalSequences(authr, user) + err = c.regeneratePrincipalSequences(ctx, authr, user) if err != nil { return err } @@ -1692,8 +1692,8 @@ func (c *DatabaseCollection) updateAllPrincipalsSequences(ctx context.Context) e return nil } -func (c *DatabaseCollection) regeneratePrincipalSequences(authr *auth.Authenticator, princ auth.Principal) error { - nextSeq, err := c.sequences().nextSequence() +func (c *DatabaseCollection) regeneratePrincipalSequences(ctx context.Context, authr *auth.Authenticator, princ auth.Principal) error { + nextSeq, err := c.sequences().nextSequence(ctx) if err != nil { return err } @@ -1760,8 +1760,8 @@ func (db *DatabaseCollectionWithUser) getResyncedDocument(ctx context.Context, d } changedChannels, err := doc.updateChannels(ctx, channels) - changed = len(doc.Access.updateAccess(doc, access)) + - len(doc.RoleAccess.updateAccess(doc, roles)) + + changed = len(doc.Access.updateAccess(ctx, doc, access)) + + len(doc.RoleAccess.updateAccess(ctx, doc, roles)) + len(changedChannels) if err != nil { return @@ -1789,7 +1789,7 @@ func (db *DatabaseCollectionWithUser) resyncDocument(ctx context.Context, docid, if currentValue == nil || len(currentValue) == 0 { return nil, nil, deleteDoc, nil, base.ErrUpdateCancel } - doc, err := unmarshalDocumentWithXattr(docid, currentValue, currentXattr, currentUserXattr, cas, DocUnmarshalAll) + doc, err := unmarshalDocumentWithXattr(ctx, docid, currentValue, currentXattr, currentUserXattr, cas, DocUnmarshalAll) if err != nil { return nil, nil, deleteDoc, nil, err } @@ -2030,8 +2030,8 @@ func (context *DatabaseContext) AllowFlushNonCouchbaseBuckets() bool { // ////// SEQUENCE ALLOCATION: -func (context *DatabaseContext) LastSequence() (uint64, error) { - return context.sequences.lastSequence() +func (context *DatabaseContext) LastSequence(ctx context.Context) (uint64, error) { + return context.sequences.lastSequence(ctx) } // Helpers for unsupported options @@ -2141,8 +2141,8 @@ func (dbc *Database) GetDefaultDatabaseCollectionWithUser() (*DatabaseCollection }, nil } -func (dbc *DatabaseContext) AuthenticatorOptions() auth.AuthenticatorOptions { - defaultOptions := auth.DefaultAuthenticatorOptions() +func (dbc *DatabaseContext) AuthenticatorOptions(ctx context.Context) auth.AuthenticatorOptions { + defaultOptions := auth.DefaultAuthenticatorOptions(ctx) defaultOptions.MetaKeys = dbc.MetadataKeys return defaultOptions } @@ -2222,14 +2222,14 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr if err != nil { return pkgerrors.Wrapf(err, "Error starting heartbeater for bucket %s", base.MD(db.Bucket.GetName()).Redact()) } - err = heartbeater.StartSendingHeartbeats() + err = heartbeater.StartSendingHeartbeats(ctx) if err != nil { return err } db.Heartbeater = heartbeater cleanupFunctions = append(cleanupFunctions, func() { - db.Heartbeater.Stop() + db.Heartbeater.Stop(ctx) }) } @@ -2258,7 +2258,7 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr }) // Get current value of _sync:seq - initialSequence, seqErr := db.sequences.lastSequence() + initialSequence, seqErr := db.sequences.lastSequence(ctx) if seqErr != nil { return seqErr } @@ -2267,14 +2267,14 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr // Unlock change cache. Validate that any allocated sequences on other nodes have either been assigned or released // before starting if initialSequence > 0 { - _ = db.sequences.waitForReleasedSequences(initialSequenceTime) + _ = db.sequences.waitForReleasedSequences(ctx, initialSequenceTime) } if err := db.changeCache.Start(initialSequence); err != nil { return err } cleanupFunctions = append(cleanupFunctions, func() { - db.changeCache.Stop() + db.changeCache.Stop(ctx) }) // If this is an xattr import node, start import feed. Must be started after the caching DCP feed, as import cfg @@ -2328,7 +2328,7 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr db.LocalJWTProviders = make(auth.LocalJWTProviderMap, len(db.Options.LocalJWTConfig)) for name, cfg := range db.Options.LocalJWTConfig { - db.LocalJWTProviders[name] = cfg.BuildProvider(name) + db.LocalJWTProviders[name] = cfg.BuildProvider(ctx, name) } if db.UseXattrs() { @@ -2372,7 +2372,7 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr } db.backgroundTasks = append(db.backgroundTasks, bgtSyncTime) - if err := base.RequireNoBucketTTL(db.Bucket); err != nil { + if err := base.RequireNoBucketTTL(ctx, db.Bucket); err != nil { return err } @@ -2381,7 +2381,7 @@ func (db *DatabaseContext) StartOnlineProcesses(ctx context.Context) (returnedEr // Start checking heartbeats for other nodes. Must be done after caching feed starts, to ensure any removals // are detected and processed by this node. if db.Heartbeater != nil { - if err := db.Heartbeater.StartCheckingHeartbeats(); err != nil { + if err := db.Heartbeater.StartCheckingHeartbeats(ctx); err != nil { return err } // No cleanup necessary, stop heartbeater above will take care of it @@ -2432,7 +2432,7 @@ func (dbc *DatabaseContext) InstallPrincipals(ctx context.Context, spec map[stri } - err, _ := base.RetryLoop("installPrincipals", worker, base.CreateDoublingSleeperFunc(16, 10)) + err, _ := base.RetryLoop(ctx, "installPrincipals", worker, base.CreateDoublingSleeperFunc(16, 10)) if err != nil { return err } diff --git a/db/database_collection.go b/db/database_collection.go index 10e72ea95e..ffab1ae03e 100644 --- a/db/database_collection.go +++ b/db/database_collection.go @@ -158,7 +158,7 @@ func (c *DatabaseCollection) groupID() string { // FlushChannelCache flush support. Currently test-only - added for unit test access from rest package func (c *DatabaseCollection) FlushChannelCache(ctx context.Context) error { base.InfofCtx(ctx, base.KeyCache, "Flushing channel cache") - return c.dbCtx.changeCache.Clear() + return c.dbCtx.changeCache.Clear(ctx) } // FlushRevisionCacheForTest creates a new revision cache. This is currently at the database level. Only use this in test code. @@ -197,8 +197,8 @@ func (c *DatabaseCollection) isGuestReadOnly() bool { } // LastSequence returns the highest sequence number allocated for this collection. -func (c *DatabaseCollection) LastSequence() (uint64, error) { - return c.dbCtx.sequences.lastSequence() +func (c *DatabaseCollection) LastSequence(ctx context.Context) (uint64, error) { + return c.dbCtx.sequences.lastSequence(ctx) } // localDocExpirySecs returns the expiry for docs tracking Couchbase Lite replication state. This is controlled at the database level. @@ -238,8 +238,8 @@ func (c *DatabaseCollectionWithUser) ReloadUser(ctx context.Context) error { } // RemoveFromChangeCache removes select documents from all channel caches and returns the number of documents removed. -func (c *DatabaseCollection) RemoveFromChangeCache(docIDs []string, startTime time.Time) int { - return c.dbCtx.changeCache.Remove(c.GetCollectionID(), docIDs, startTime) +func (c *DatabaseCollection) RemoveFromChangeCache(ctx context.Context, docIDs []string, startTime time.Time) int { + return c.dbCtx.changeCache.Remove(ctx, c.GetCollectionID(), docIDs, startTime) } // revsLimit is the max depth a document's revision tree can grow to. This is controlled at a database level. diff --git a/db/database_stats.go b/db/database_stats.go index 3cbcc2cc76..3008e02bdb 100644 --- a/db/database_stats.go +++ b/db/database_stats.go @@ -10,6 +10,8 @@ licenses/APL2.txt. package db +import "context" + // Wrapper around *expvars.Map for database stats that provide: // // - A lazy loading mechanism @@ -20,11 +22,11 @@ package db // } // Update database-specific stats that are more efficiently calculated at stats collection time -func (db *DatabaseContext) UpdateCalculatedStats() { +func (db *DatabaseContext) UpdateCalculatedStats(ctx context.Context) { - db.changeCache.updateStats() + db.changeCache.updateStats(ctx) channelCache := db.changeCache.getChannelCache() - db.DbStats.Cache().ChannelCacheMaxEntries.Set(int64(channelCache.MaxCacheSize())) + db.DbStats.Cache().ChannelCacheMaxEntries.Set(int64(channelCache.MaxCacheSize(ctx))) db.DbStats.Cache().HighSeqCached.Set(int64(channelCache.GetHighCacheSequence())) } diff --git a/db/database_test.go b/db/database_test.go index 560ee8a8a4..9cff4a273f 100644 --- a/db/database_test.go +++ b/db/database_test.go @@ -336,7 +336,7 @@ func TestGetDeleted(t *testing.T) { assert.Equal(t, rev2id, doc.SyncData.CurrentRev) // Try again but with a user who doesn't have access to this revision (see #179) - authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions(ctx)) collection.user, err = authenticator.GetUser("") assert.NoError(t, err, "GetUser") collection.user.SetExplicitChannels(nil, 1) @@ -403,7 +403,7 @@ func TestGetRemovedAsUser(t *testing.T) { assert.NoError(t, err, "Purge old revision JSON") // Try again with a user who doesn't have access to this revision - authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions(ctx)) collection.user, err = authenticator.GetUser("") assert.NoError(t, err, "GetUser") @@ -505,8 +505,8 @@ func TestGetRemovalMultiChannel(t *testing.T) { body, err := collection.Get1xRevBody(ctx, "doc1", rev2ID, true, nil) require.NoError(t, err, "Error getting 1x rev body") - _, rev1Digest := ParseRevID(rev1ID) - _, rev2Digest := ParseRevID(rev2ID) + _, rev1Digest := ParseRevID(ctx, rev1ID) + _, rev2Digest := ParseRevID(ctx, rev2ID) bodyExpected := Body{ "k2": "v2", @@ -871,7 +871,7 @@ func TestAllDocsOnly(t *testing.T) { collectionID := collection.GetCollectionID() // Trigger creation of the channel cache for channel "all" - _, err := db.changeCache.getChannelCache().getSingleChannelCache(channels.NewID("all", collectionID)) + _, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, channels.NewID("all", collectionID)) require.NoError(t, err) ids := make([]AllDocsEntry, 100) @@ -916,14 +916,14 @@ func TestAllDocsOnly(t *testing.T) { err = db.changeCache.waitForSequence(ctx, 101, base.DefaultWaitForSequence) require.NoError(t, err) - changeLog, err := collection.GetChangeLog(channels.NewID("all", collectionID), 0) + changeLog, err := collection.GetChangeLog(ctx, channels.NewID("all", collectionID), 0) require.NoError(t, err) require.Len(t, changeLog, 50) assert.Equal(t, "alldoc-51", changeLog[0].DocID) // Now check the changes feed: var options ChangesOptions - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(base.TestCtx(t)) options.ChangesCtx = changesCtx defer changesCtxCancel() changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "all"), options) @@ -994,7 +994,7 @@ func TestUpdatePrincipal(t *testing.T) { _, err = db.UpdatePrincipal(ctx, userInfo, true, true) assert.NoError(t, err, "Unable to update principal") - nextSeq, err := db.sequences.nextSequence() + nextSeq, err := db.sequences.nextSequence(ctx) require.NoError(t, err) assert.Equal(t, uint64(1), nextSeq) @@ -1005,7 +1005,7 @@ func TestUpdatePrincipal(t *testing.T) { _, err = db.UpdatePrincipal(ctx, userInfo, true, true) assert.NoError(t, err, "Unable to update principal") - nextSeq, err = db.sequences.nextSequence() + nextSeq, err = db.sequences.nextSequence(ctx) require.NoError(t, err) assert.Equal(t, uint64(3), nextSeq) } @@ -1034,7 +1034,7 @@ func TestRepeatedConflict(t *testing.T) { assert.NoError(t, err, "add 2-a") // Get the _rev that was set in the body by PutExistingRevWithBody() and make assertions on it - revGen, _ := ParseRevID(newRev) + revGen, _ := ParseRevID(ctx, newRev) assert.Equal(t, 2, revGen) // Remove the _rev key from the body, and call PutExistingRevWithBody() again, which should re-add it @@ -1043,7 +1043,7 @@ func TestRepeatedConflict(t *testing.T) { assert.NoError(t, err) // The _rev should pass the same assertions as before, since PutExistingRevWithBody() should re-add it - revGen, _ = ParseRevID(newRev) + revGen, _ = ParseRevID(ctx, newRev) assert.Equal(t, 2, revGen) } @@ -1060,7 +1060,7 @@ func TestConflicts(t *testing.T) { collectionID := collection.GetCollectionID() allChannel := channels.NewID("all", collectionID) - _, err := db.changeCache.getChannelCache().getSingleChannelCache(allChannel) + _, err := db.changeCache.getChannelCache().getSingleChannelCache(ctx, allChannel) require.NoError(t, err) cacheWaiter := db.NewDCPCachingCountWaiter(t) @@ -1073,7 +1073,7 @@ func TestConflicts(t *testing.T) { // Wait for rev to be cached cacheWaiter.AddAndWait(1) - changeLog, err := collection.GetChangeLog(channels.NewID("all", collectionID), 0) + changeLog, err := collection.GetChangeLog(ctx, channels.NewID("all", collectionID), 0) require.NoError(t, err) assert.Equal(t, 1, len(changeLog)) @@ -1111,7 +1111,7 @@ func TestConflicts(t *testing.T) { // Verify the change-log of the "all" channel: cacheWaiter.Wait() - changeLog, err = collection.GetChangeLog(allChannel, 0) + changeLog, err = collection.GetChangeLog(ctx, allChannel, 0) require.NoError(t, err) assert.Equal(t, 1, len(changeLog)) assert.Equal(t, uint64(3), changeLog[0].Sequence) @@ -1122,7 +1122,7 @@ func TestConflicts(t *testing.T) { // Verify the _changes feed: options := ChangesOptions{ Conflicts: true, - ChangesCtx: context.Background(), + ChangesCtx: base.TestCtx(t), } changes, err := collection.GetChanges(ctx, channels.BaseSetOf(t, "all"), options) assert.NoError(t, err, "Couldn't GetChanges") @@ -1615,7 +1615,7 @@ func TestUpdateDesignDoc(t *testing.T) { assert.True(t, strings.Contains(retrievedView.Map, "emit()")) assert.NotEqual(t, mapFunction, retrievedView.Map) // SG should wrap the map function, so they shouldn't be equal - authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions(ctx)) db.user, _ = authenticator.NewUser("naomi", "letmein", channels.BaseSetOf(t, "Netflix")) err = db.PutDesignDoc("_design/pwn3d", sgbucket.DesignDoc{}) assertHTTPError(t, err, 403) @@ -1703,7 +1703,7 @@ func TestPostWithUserSpecialProperty(t *testing.T) { doc, err = collection.GetDocument(ctx, docid, DocUnmarshalAll) require.NotNil(t, doc) assert.Equal(t, rev2id, doc.CurrentRev) - assert.Equal(t, "value", doc.Body()["_special"]) + assert.Equal(t, "value", doc.Body(ctx)["_special"]) assert.NoError(t, err, "Unable to retrieve doc using generated uuid") } @@ -2167,8 +2167,8 @@ func TestConcurrentPushSameNewRevision(t *testing.T) { doc, err := collection.GetDocument(ctx, "doc1", DocUnmarshalAll) assert.Equal(t, revId, doc.RevID) assert.NoError(t, err, "Couldn't retrieve document") - assert.Equal(t, "Bob", doc.Body()["name"]) - assert.Equal(t, json.Number("52"), doc.Body()["age"]) + assert.Equal(t, "Bob", doc.Body(ctx)["name"]) + assert.Equal(t, json.Number("52"), doc.Body(ctx)["age"]) } // Multiple clients are attempting to push the same new, non-winning revision concurrently; non-winning is an @@ -2762,7 +2762,7 @@ func Test_invalidateAllPrincipalsCache(t *testing.T) { role, err := auth.NewRole(fmt.Sprintf("role%d", i), base.SetOf("ABC")) assert.NoError(t, err) assert.NotEmpty(t, role) - seq, err := db.sequences.nextSequence() + seq, err := db.sequences.nextSequence(ctx) assert.NoError(t, err) role.SetSequence(seq) err = auth.Save(role) @@ -2771,7 +2771,7 @@ func Test_invalidateAllPrincipalsCache(t *testing.T) { user, err := auth.NewUser(fmt.Sprintf("user%d", i), "letmein", base.SetOf("ABC")) assert.NoError(t, err) assert.NotEmpty(t, user) - seq, err = db.sequences.nextSequence() + seq, err = db.sequences.nextSequence(ctx) assert.NoError(t, err) user.SetSequence(seq) err = auth.Save(user) @@ -3006,10 +3006,9 @@ func TestImportCompactPanic(t *testing.T) { require.NoError(t, collection.WaitForPendingChanges(ctx)) // Wait for Compact to run - in the failing case it'll panic before incrementing the stat - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return db.DbStats.Database().NumTombstonesCompacted.Value() }, 1) - require.True(t, ok) } func TestGetDatabaseCollectionWithUserScopesNil(t *testing.T) { @@ -3230,14 +3229,15 @@ func Test_waitForBackgroundManagersToStop(t *testing.T) { name: "test_unstoppable_runner", Process: &testBackgroundProcess{isStoppable: false}, } - err := bgMngr.Start(context.TODO(), map[string]interface{}{}) + ctx := base.TestCtx(t) + err := bgMngr.Start(ctx, map[string]interface{}{}) require.NoError(t, err) err = bgMngr.Stop() require.NoError(t, err) startTime := time.Now() deadline := 10 * time.Second - waitForBackgroundManagersToStop(context.TODO(), deadline, []*BackgroundManager{bgMngr}) + waitForBackgroundManagersToStop(ctx, deadline, []*BackgroundManager{bgMngr}) assert.Greater(t, time.Since(startTime), deadline) assert.Equal(t, BackgroundProcessStateStopping, bgMngr.GetRunState()) }) @@ -3247,14 +3247,15 @@ func Test_waitForBackgroundManagersToStop(t *testing.T) { name: "test_stoppable_runner", Process: &testBackgroundProcess{isStoppable: true}, } - err := bgMngr.Start(context.TODO(), map[string]interface{}{}) + ctx := base.TestCtx(t) + err := bgMngr.Start(ctx, map[string]interface{}{}) require.NoError(t, err) err = bgMngr.Stop() require.NoError(t, err) startTime := time.Now() deadline := 10 * time.Second - waitForBackgroundManagersToStop(context.TODO(), deadline, []*BackgroundManager{bgMngr}) + waitForBackgroundManagersToStop(ctx, deadline, []*BackgroundManager{bgMngr}) assert.Less(t, time.Since(startTime), deadline) assert.Equal(t, BackgroundProcessStateStopped, bgMngr.GetRunState()) }) @@ -3264,7 +3265,8 @@ func Test_waitForBackgroundManagersToStop(t *testing.T) { name: "test_stoppable_runner", Process: &testBackgroundProcess{isStoppable: true}, } - err := stoppableBgMngr.Start(context.TODO(), map[string]interface{}{}) + ctx := base.TestCtx(t) + err := stoppableBgMngr.Start(ctx, map[string]interface{}{}) require.NoError(t, err) err = stoppableBgMngr.Stop() require.NoError(t, err) @@ -3274,14 +3276,14 @@ func Test_waitForBackgroundManagersToStop(t *testing.T) { Process: &testBackgroundProcess{isStoppable: false}, } - err = unstoppableBgMngr.Start(context.TODO(), map[string]interface{}{}) + err = unstoppableBgMngr.Start(ctx, map[string]interface{}{}) require.NoError(t, err) err = unstoppableBgMngr.Stop() require.NoError(t, err) startTime := time.Now() deadline := 10 * time.Second - waitForBackgroundManagersToStop(context.TODO(), deadline, []*BackgroundManager{stoppableBgMngr, unstoppableBgMngr}) + waitForBackgroundManagersToStop(ctx, deadline, []*BackgroundManager{stoppableBgMngr, unstoppableBgMngr}) assert.Greater(t, time.Since(startTime), deadline) assert.Equal(t, BackgroundProcessStateStopped, stoppableBgMngr.GetRunState()) assert.Equal(t, BackgroundProcessStateStopping, unstoppableBgMngr.GetRunState()) diff --git a/db/dcp_sharded_upgrade_test.go b/db/dcp_sharded_upgrade_test.go index 4d2f3f98b9..5c676a179e 100644 --- a/db/dcp_sharded_upgrade_test.go +++ b/db/dcp_sharded_upgrade_test.go @@ -236,7 +236,7 @@ func TestShardedDCPUpgrade(t *testing.T) { ctx = db.AddDatabaseLogContext(ctx) collection := GetSingleDatabaseCollection(t, db) - err, _ = base.RetryLoop("wait for non-existent node to be removed", func() (shouldRetry bool, err error, value interface{}) { + err, _ = base.RetryLoop(ctx, "wait for non-existent node to be removed", func() (shouldRetry bool, err error, value interface{}) { nodes, _, err := cbgt.CfgGetNodeDefs(db.CfgSG, cbgt.NODE_DEFS_KNOWN) if err != nil { return false, err, nil @@ -250,7 +250,7 @@ func TestShardedDCPUpgrade(t *testing.T) { }, base.CreateSleeperFunc(100, 100)) require.NoError(t, err) - err, _ = base.RetryLoop("wait for all pindexes to be reassigned", func() (shouldRetry bool, err error, value interface{}) { + err, _ = base.RetryLoop(ctx, "wait for all pindexes to be reassigned", func() (shouldRetry bool, err error, value interface{}) { pIndexes, _, err := cbgt.CfgGetPlanPIndexes(db.CfgSG) if err != nil { return false, nil, err diff --git a/db/design_doc.go b/db/design_doc.go index b72548fefc..bc19d613d9 100644 --- a/db/design_doc.go +++ b/db/design_doc.go @@ -622,7 +622,7 @@ func installViews(ctx context.Context, viewStore sgbucket.ViewStore) error { } description := fmt.Sprintf("Attempt to install Couchbase design doc") - err, _ := base.RetryLoop(description, worker, sleeper) + err, _ := base.RetryLoop(ctx, description, worker, sleeper) if err != nil { return pkgerrors.WithStack(base.RedactErrorf("Error installing Couchbase Design doc: %v. Error: %v", base.UD(designDocName), err)) @@ -787,7 +787,7 @@ func getViewStoreForDefaultCollection(dbContext *DatabaseContext) (sgbucket.View } vs, ok := base.AsViewStore(dbCollection.dataStore) if !ok { - return nil, fmt.Errorf("dbCollection.dataStore is not a ViewStore") + return nil, fmt.Errorf("%T is not a ViewStore", dbCollection.dataStore) } return vs, nil } diff --git a/db/document.go b/db/document.go index 625fbf0027..f16360397c 100644 --- a/db/document.go +++ b/db/document.go @@ -221,8 +221,8 @@ func (doc *Document) IsDeleted() bool { return doc.hasFlag(channels.Deleted) } -func (doc *Document) BodyWithSpecialProperties() ([]byte, error) { - bodyBytes, err := doc.BodyBytes() +func (doc *Document) BodyWithSpecialProperties(ctx context.Context) ([]byte, error) { + bodyBytes, err := doc.BodyBytes(ctx) if err != nil { return nil, err } @@ -254,26 +254,26 @@ func NewDocument(docid string) *Document { } // Accessors for document properties. To support lazy unmarshalling of document contents, all access should be done through accessors -func (doc *Document) Body() Body { +func (doc *Document) Body(ctx context.Context) Body { var caller string if base.ConsoleLogLevel().Enabled(base.LevelTrace) { caller = base.GetCallersName(1, true) } if doc._body != nil { - base.TracefCtx(context.Background(), base.KeyAll, "Already had doc body %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) + base.TracefCtx(ctx, base.KeyAll, "Already had doc body %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) return doc._body } if doc._rawBody == nil { - base.WarnfCtx(context.Background(), "Null doc body/rawBody %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) + base.WarnfCtx(ctx, "Null doc body/rawBody %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) return nil } - base.TracefCtx(context.Background(), base.KeyAll, " UNMARSHAL doc body %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) + base.TracefCtx(ctx, base.KeyAll, " UNMARSHAL doc body %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) err := doc._body.Unmarshal(doc._rawBody) if err != nil { - base.WarnfCtx(context.Background(), "Unable to unmarshal document body from raw body : %s", err) + base.WarnfCtx(ctx, "Unable to unmarshal document body from raw body : %s", err) return nil } return doc._body @@ -312,7 +312,7 @@ func (doc *Document) HasBody() bool { return doc._body != nil || doc._rawBody != nil } -func (doc *Document) BodyBytes() ([]byte, error) { +func (doc *Document) BodyBytes(ctx context.Context) ([]byte, error) { var caller string if base.ConsoleLogLevel().Enabled(base.LevelTrace) { caller = base.GetCallersName(1, true) @@ -323,7 +323,7 @@ func (doc *Document) BodyBytes() ([]byte, error) { } if doc._body == nil { - base.WarnfCtx(context.Background(), "Null doc body/rawBody %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) + base.WarnfCtx(ctx, "Null doc body/rawBody %s/%s from %s", base.UD(doc.ID), base.UD(doc.RevID), caller) return nil, nil } @@ -389,14 +389,14 @@ func unmarshalDocument(docid string, data []byte) (*Document, error) { return doc, nil } -func unmarshalDocumentWithXattr(docid string, data []byte, xattrData []byte, userXattrData []byte, cas uint64, unmarshalLevel DocumentUnmarshalLevel) (doc *Document, err error) { +func unmarshalDocumentWithXattr(ctx context.Context, docid string, data []byte, xattrData []byte, userXattrData []byte, cas uint64, unmarshalLevel DocumentUnmarshalLevel) (doc *Document, err error) { if xattrData == nil || len(xattrData) == 0 { // If no xattr data, unmarshal as standard doc doc, err = unmarshalDocument(docid, data) } else { doc = NewDocument(docid) - err = doc.UnmarshalWithXattr(data, xattrData, unmarshalLevel) + err = doc.UnmarshalWithXattr(ctx, data, xattrData, unmarshalLevel) } if err != nil { return nil, err @@ -468,7 +468,7 @@ func UnmarshalDocumentSyncDataFromFeed(data []byte, dataType uint8, userXattrKey return result, body, rawUserXattr, nil, err } -func UnmarshalDocumentFromFeed(docid string, cas uint64, data []byte, dataType uint8, userXattrKey string) (doc *Document, err error) { +func UnmarshalDocumentFromFeed(ctx context.Context, docid string, cas uint64, data []byte, dataType uint8, userXattrKey string) (doc *Document, err error) { var body []byte if dataType&base.MemcachedDataTypeXattr != 0 { @@ -478,7 +478,7 @@ func UnmarshalDocumentFromFeed(docid string, cas uint64, data []byte, dataType u if err != nil { return nil, err } - return unmarshalDocumentWithXattr(docid, body, syncXattr, userXattr, cas, DocUnmarshalAll) + return unmarshalDocumentWithXattr(ctx, docid, body, syncXattr, userXattr, cas, DocUnmarshalAll) } return unmarshalDocument(docid, data) @@ -622,7 +622,7 @@ func (doc *Document) IsSGWrite(ctx context.Context, rawBody []byte) (isSGWrite b } // Since raw body isn't available, marshal from the document to perform body hash comparison - bodyBytes, err := doc.BodyBytes() + bodyBytes, err := doc.BodyBytes(ctx) if err != nil { base.WarnfCtx(ctx, "Unable to marshal doc body during SG write check for doc %s. Error: %v", base.UD(doc.ID), err) return false, false, false @@ -678,19 +678,19 @@ func (c *DatabaseCollection) RevisionBodyLoader(key string) ([]byte, error) { } // Fetches the body of a revision as a map, or nil if it's not available. -func (doc *Document) getRevisionBody(revid string, loader RevLoaderFunc) Body { +func (doc *Document) getRevisionBody(ctx context.Context, revid string, loader RevLoaderFunc) Body { var body Body if revid == doc.CurrentRev { - body = doc.Body() + body = doc.Body(ctx) } else { - body = doc.getNonWinningRevisionBody(revid, loader) + body = doc.getNonWinningRevisionBody(ctx, revid, loader) } return body } // Retrieves a non-winning revision body. If not already loaded in the document (either because inline, // or was previously requested), loader function is used to retrieve from the bucket. -func (doc *Document) getNonWinningRevisionBody(revid string, loader RevLoaderFunc) Body { +func (doc *Document) getNonWinningRevisionBody(ctx context.Context, revid string, loader RevLoaderFunc) Body { var body Body bodyBytes, found := doc.History.getRevisionBody(revid, loader) if !found || len(bodyBytes) == 0 { @@ -698,7 +698,7 @@ func (doc *Document) getNonWinningRevisionBody(revid string, loader RevLoaderFun } if err := body.Unmarshal(bodyBytes); err != nil { - base.WarnfCtx(context.TODO(), "Unexpected error parsing body of rev %q: %v", revid, err) + base.WarnfCtx(ctx, "Unexpected error parsing body of rev %q: %v", revid, err) return nil } return body @@ -709,7 +709,7 @@ func (doc *Document) getRevisionBodyJSON(ctx context.Context, revid string, load var bodyJSON []byte if revid == doc.CurrentRev { var marshalErr error - bodyJSON, marshalErr = doc.BodyBytes() + bodyJSON, marshalErr = doc.BodyBytes(ctx) if marshalErr != nil { base.WarnfCtx(ctx, "Marshal error when retrieving active current revision body: %v", marshalErr) } @@ -719,8 +719,8 @@ func (doc *Document) getRevisionBodyJSON(ctx context.Context, revid string, load return bodyJSON } -func (doc *Document) removeRevisionBody(revID string) { - removedBodyKey := doc.History.removeRevisionBody(revID) +func (doc *Document) removeRevisionBody(ctx context.Context, revID string) { + removedBodyKey := doc.History.removeRevisionBody(ctx, revID) if removedBodyKey != "" { if doc.removedRevisionBodyKeys == nil { doc.removedRevisionBodyKeys = make(map[string]string) @@ -730,15 +730,15 @@ func (doc *Document) removeRevisionBody(revID string) { } // makeBodyActive moves a previously non-winning revision body from the rev tree to the document body -func (doc *Document) promoteNonWinningRevisionBody(revid string, loader RevLoaderFunc) { +func (doc *Document) promoteNonWinningRevisionBody(ctx context.Context, revid string, loader RevLoaderFunc) { // If the new revision is not current, transfer the current revision's // body to the top level doc._body: - doc.UpdateBody(doc.getNonWinningRevisionBody(revid, loader)) - doc.removeRevisionBody(revid) + doc.UpdateBody(doc.getNonWinningRevisionBody(ctx, revid, loader)) + doc.removeRevisionBody(ctx, revid) } -func (doc *Document) pruneRevisions(maxDepth uint32, keepRev string) int { - numPruned, prunedTombstoneBodyKeys := doc.History.pruneRevisions(maxDepth, keepRev) +func (doc *Document) pruneRevisions(ctx context.Context, maxDepth uint32, keepRev string) int { + numPruned, prunedTombstoneBodyKeys := doc.History.pruneRevisions(ctx, maxDepth, keepRev) for revID, bodyKey := range prunedTombstoneBodyKeys { if doc.removedRevisionBodyKeys == nil { doc.removedRevisionBodyKeys = make(map[string]string) @@ -749,12 +749,12 @@ func (doc *Document) pruneRevisions(maxDepth uint32, keepRev string) int { } // Adds a revision body (as Body) to a document. Removes special properties first. -func (doc *Document) setRevisionBody(revid string, newDoc *Document, storeInline, hasAttachments bool) { +func (doc *Document) setRevisionBody(ctx context.Context, revid string, newDoc *Document, storeInline, hasAttachments bool) { if revid == doc.CurrentRev { doc._body = newDoc._body doc._rawBody = newDoc._rawBody } else { - bodyBytes, _ := newDoc.BodyBytes() + bodyBytes, _ := newDoc.BodyBytes(ctx) doc.setNonWinningRevisionBody(revid, bodyBytes, storeInline, hasAttachments) } } @@ -802,12 +802,12 @@ func (doc *Document) persistModifiedRevisionBodies(datastore sgbucket.DataStore) // deleteRemovedRevisionBodies deletes obsolete non-inline revisions from the bucket. // Should be invoked AFTER the document is successfully committed. -func (doc *Document) deleteRemovedRevisionBodies(dataStore base.DataStore) { +func (doc *Document) deleteRemovedRevisionBodies(ctx context.Context, dataStore base.DataStore) { for _, revBodyKey := range doc.removedRevisionBodyKeys { deleteErr := dataStore.Delete(revBodyKey) if deleteErr != nil { - base.WarnfCtx(context.TODO(), "Unable to delete old revision body using key %s - will not be deleted from bucket.", revBodyKey) + base.WarnfCtx(ctx, "Unable to delete old revision body using key %s - will not be deleted from bucket.", revBodyKey) } } doc.removedRevisionBodyKeys = map[string]string{} @@ -819,7 +819,7 @@ func (doc *Document) persistRevisionBody(datastore sgbucket.DataStore, key strin } // Move any large revision bodies to external document storage -func (doc *Document) migrateRevisionBodies(dataStore base.DataStore) error { +func (doc *Document) migrateRevisionBodies(ctx context.Context, dataStore base.DataStore) error { for _, revID := range doc.History.GetLeaves() { revInfo, err := doc.History.getInfo(revID) @@ -830,7 +830,7 @@ func (doc *Document) migrateRevisionBodies(dataStore base.DataStore) error { bodyKey := generateRevBodyKey(doc.ID, revID) persistErr := doc.persistRevisionBody(dataStore, bodyKey, revInfo.Body) if persistErr != nil { - base.WarnfCtx(context.TODO(), "Unable to store revision body for doc %s, rev %s externally: %v", base.UD(doc.ID), revID, persistErr) + base.WarnfCtx(ctx, "Unable to store revision body for doc %s, rev %s externally: %v", base.UD(doc.ID), revID, persistErr) continue } revInfo.BodyKey = bodyKey @@ -992,7 +992,7 @@ func (doc *Document) updateChannels(ctx context.Context, newChannels base.Set) ( // Determine whether the specified revision was a channel removal, based on doc.Channels. If so, construct the standard document body for a // removal notification (_removed=true) // Set of channels returned from IsChannelRemoval are "Active" channels and NOT "Removed". -func (doc *Document) IsChannelRemoval(revID string) (bodyBytes []byte, history Revisions, channels base.Set, isRemoval bool, isDelete bool, err error) { +func (doc *Document) IsChannelRemoval(ctx context.Context, revID string) (bodyBytes []byte, history Revisions, channels base.Set, isRemoval bool, isDelete bool, err error) { removedChannels := make(base.Set) @@ -1032,14 +1032,14 @@ func (doc *Document) IsChannelRemoval(revID string) (bodyBytes []byte, history R if len(revHistory) == 0 { revHistory = []string{revID} } - history = encodeRevisions(doc.ID, revHistory) + history = encodeRevisions(ctx, doc.ID, revHistory) return bodyBytes, history, activeChannels, true, isDelete, nil } // Updates a document's channel/role UserAccessMap with new access settings from an AccessMap. // Returns an array of the user/role names whose access has changed as a result. -func (accessMap *UserAccessMap) updateAccess(doc *Document, newAccess channels.AccessMap) (changedUsers []string) { +func (accessMap *UserAccessMap) updateAccess(ctx context.Context, doc *Document, newAccess channels.AccessMap) (changedUsers []string) { // Update users already appearing in doc.Access: for name, access := range *accessMap { if access.UpdateAtSequence(newAccess[name], doc.Sequence) { @@ -1064,7 +1064,7 @@ func (accessMap *UserAccessMap) updateAccess(doc *Document, newAccess channels.A if accessMap == &doc.RoleAccess { what = "role" } - base.InfofCtx(context.TODO(), base.KeyAccess, "Doc %q grants %s access: %v", base.UD(doc.ID), what, base.UD(*accessMap)) + base.InfofCtx(ctx, base.KeyAccess, "Doc %q grants %s access: %v", base.UD(doc.ID), what, base.UD(*accessMap)) } return changedUsers } @@ -1126,9 +1126,9 @@ func (doc *Document) MarshalJSON() (data []byte, err error) { // (unmarshalLevel) specifies how much of the provided document/xattr needs to be initially unmarshalled. If // unmarshalLevel is anything less than the full document + metadata, the raw data is retained for subsequent // lazy unmarshalling as needed. -func (doc *Document) UnmarshalWithXattr(data []byte, xdata []byte, unmarshalLevel DocumentUnmarshalLevel) error { +func (doc *Document) UnmarshalWithXattr(ctx context.Context, data []byte, xdata []byte, unmarshalLevel DocumentUnmarshalLevel) error { if doc.ID == "" { - base.WarnfCtx(context.Background(), "Attempted to unmarshal document without ID set") + base.WarnfCtx(ctx, "Attempted to unmarshal document without ID set") return errors.New("Document was unmarshalled without ID set") } diff --git a/db/document_test.go b/db/document_test.go index e15d49ca0a..ba366dd435 100644 --- a/db/document_test.go +++ b/db/document_test.go @@ -132,8 +132,9 @@ func BenchmarkDocUnmarshal(b *testing.B) { for _, bm := range unmarshalBenchmarks { b.Run(bm.name, func(b *testing.B) { + ctx := base.TestCtx(b) for i := 0; i < b.N; i++ { - _, _ = unmarshalDocumentWithXattr("doc_1k", doc1k_body, doc1k_meta, nil, 1, bm.unmarshalLevel) + _, _ = unmarshalDocumentWithXattr(ctx, "doc_1k", doc1k_body, doc1k_meta, nil, 1, bm.unmarshalLevel) } }) } @@ -154,6 +155,7 @@ func BenchmarkUnmarshalBody(b *testing.B) { for _, bm := range unmarshalBenchmarks { b.Run(bm.name, func(b *testing.B) { + ctx := base.TestCtx(b) for i := 0; i < b.N; i++ { b.StopTimer() doc := NewDocument("testDocID") @@ -170,7 +172,7 @@ func BenchmarkUnmarshalBody(b *testing.B) { } else { err = base.JSONUnmarshal(doc1k_body, &doc._body) if bm.fixJSONNumbers { - doc.Body().FixJSONNumbers() + doc.Body(ctx).FixJSONNumbers() } } b.StopTimer() @@ -178,7 +180,7 @@ func BenchmarkUnmarshalBody(b *testing.B) { log.Printf("Unmarshal error: %s", err) } - if len(doc.Body()) == 0 { + if len(doc.Body(ctx)) == 0 { log.Printf("Empty body") } diff --git a/db/event.go b/db/event.go index 55e169b246..6d726ee357 100644 --- a/db/event.go +++ b/db/event.go @@ -112,13 +112,13 @@ type jsEventTask struct { } // Compiles a JavaScript event function to a jsEventTask object. -func newJsEventTask(funcSource string) (sgbucket.JSServerTask, error) { +func newJsEventTask(ctx context.Context, funcSource string) (sgbucket.JSServerTask, error) { eventTask := &jsEventTask{} err := eventTask.InitWithLogging(funcSource, 0, func(s string) { - base.ErrorfCtx(context.Background(), base.KeyJavascript.String()+": Webhook %s", base.UD(s)) + base.ErrorfCtx(ctx, base.KeyJavascript.String()+": Webhook %s", base.UD(s)) }, - func(s string) { base.InfofCtx(context.Background(), base.KeyJavascript, "Webhook %s", base.UD(s)) }) + func(s string) { base.InfofCtx(ctx, base.KeyJavascript, "Webhook %s", base.UD(s)) }) if err != nil { return nil, err } @@ -153,13 +153,13 @@ type JSEventFunction struct { *sgbucket.JSServer } -func NewJSEventFunction(fnSource string) *JSEventFunction { +func NewJSEventFunction(ctx context.Context, fnSource string) *JSEventFunction { - base.InfofCtx(context.Background(), base.KeyEvents, "Creating new JSEventFunction") + base.InfofCtx(ctx, base.KeyEvents, "Creating new JSEventFunction") return &JSEventFunction{ JSServer: sgbucket.NewJSServer(fnSource, 0, kTaskCacheSize, func(fnSource string, timeout time.Duration) (sgbucket.JSServerTask, error) { - return newJsEventTask(fnSource) + return newJsEventTask(ctx, fnSource) }), } } @@ -167,6 +167,8 @@ func NewJSEventFunction(fnSource string) *JSEventFunction { // Calls a jsEventFunction returning an interface{} func (ef *JSEventFunction) CallFunction(event Event) (interface{}, error) { + ctx := context.TODO() // fix in sg-bucket + var err error var result interface{} @@ -178,12 +180,12 @@ func (ef *JSEventFunction) CallFunction(event Event) (interface{}, error) { case *DBStateChangeEvent: result, err = ef.Call(event.Doc) default: - base.WarnfCtx(context.TODO(), "unknown event %v tried to call function", event.EventType()) + base.WarnfCtx(ctx, "unknown event %v tried to call function", event.EventType()) return "", fmt.Errorf("unknown event %v tried to call function", event.EventType()) } if err != nil { - base.WarnfCtx(context.TODO(), "Error calling function - function processing aborted: %v", err) + base.WarnfCtx(ctx, "Error calling function - function processing aborted: %v", err) return "", err } diff --git a/db/event_handler.go b/db/event_handler.go index 7bf8102436..6b7683b159 100644 --- a/db/event_handler.go +++ b/db/event_handler.go @@ -24,7 +24,7 @@ import ( // EventHandler interface represents an instance of an event handler defined in the database config type EventHandler interface { - HandleEvent(event Event) bool + HandleEvent(ctx context.Context, event Event) bool String() string } @@ -50,7 +50,7 @@ const ( ) // Creates a new webhook handler based on the url and filter function. -func NewWebhook(url string, filterFnString string, timeout *uint64, options map[string]interface{}) (*Webhook, error) { +func NewWebhook(ctx context.Context, url string, filterFnString string, timeout *uint64, options map[string]interface{}) (*Webhook, error) { var err error @@ -63,7 +63,7 @@ func NewWebhook(url string, filterFnString string, timeout *uint64, options map[ url: url, } if filterFnString != "" { - wh.filter = NewJSEventFunction(filterFnString) + wh.filter = NewJSEventFunction(ctx, filterFnString) } if timeout != nil { @@ -87,11 +87,10 @@ func NewWebhook(url string, filterFnString string, timeout *uint64, options map[ // Performs an HTTP POST to the url defined for the handler. If a filter function is defined, // calls it to determine whether to POST. The payload for the POST is depends // on the event type. -func (wh *Webhook) HandleEvent(event Event) bool { +func (wh *Webhook) HandleEvent(ctx context.Context, event Event) bool { const contentType = "application/json" var payload []byte - logCtx := context.TODO() // Different events post different content by default switch event := event.(type) { @@ -112,12 +111,12 @@ func (wh *Webhook) HandleEvent(event Event) bool { //} jsonOut, err := base.JSONMarshal(event.Doc) if err != nil { - base.WarnfCtx(logCtx, "Error marshalling doc for webhook post") + base.WarnfCtx(ctx, "Error marshalling doc for webhook post") return false } payload = jsonOut default: - base.WarnfCtx(logCtx, "Webhook invoked for unsupported event type.") + base.WarnfCtx(ctx, "Webhook invoked for unsupported event type.") return false } @@ -125,7 +124,7 @@ func (wh *Webhook) HandleEvent(event Event) bool { // If filter function is defined, use it to determine whether to post success, err := wh.filter.CallValidateFunction(event) if err != nil { - base.WarnfCtx(logCtx, "Error calling webhook filter function: %v", err) + base.WarnfCtx(ctx, "Error calling webhook filter function: %v", err) } // If filter returns false, cancel webhook post @@ -141,24 +140,24 @@ func (wh *Webhook) HandleEvent(event Event) bool { if resp != nil && resp.Body != nil { _, err := io.Copy(io.Discard, resp.Body) if err != nil { - base.DebugfCtx(logCtx, base.KeyEvents, "Error copying response body: %v", err) + base.DebugfCtx(ctx, base.KeyEvents, "Error copying response body: %v", err) } err = resp.Body.Close() if err != nil { - base.DebugfCtx(logCtx, base.KeyEvents, "Error closing response body: %v", err) + base.DebugfCtx(ctx, base.KeyEvents, "Error closing response body: %v", err) } } }() if err != nil { - base.WarnfCtx(logCtx, "Error attempting to post %s to url %s: %s", base.UD(event.String()), base.UD(wh.SanitizedUrl()), err) + base.WarnfCtx(ctx, "Error attempting to post %s to url %s: %s", base.UD(event.String()), base.UD(wh.SanitizedUrl(ctx)), err) return false } // Check Log Level first, as SanitizedUrl is expensive to evaluate. if base.LogDebugEnabled(base.KeyEvents) { - base.DebugfCtx(logCtx, base.KeyEvents, "Webhook handler ran for event. Payload %s posted to URL %s, got status %s", - base.UD(string(payload)), base.UD(wh.SanitizedUrl()), resp.Status) + base.DebugfCtx(ctx, base.KeyEvents, "Webhook handler ran for event. Payload %s posted to URL %s, got status %s", + base.UD(string(payload)), base.UD(wh.SanitizedUrl(ctx)), resp.Status) } return true }() @@ -166,10 +165,10 @@ func (wh *Webhook) HandleEvent(event Event) bool { } func (wh *Webhook) String() string { - return fmt.Sprintf("Webhook handler [%s]", wh.SanitizedUrl()) + return fmt.Sprintf("Webhook handler [%s]", wh.SanitizedUrl(context.TODO())) // not possible to provide a better context and satisfy fmt.Stringer } -func (wh *Webhook) SanitizedUrl() string { +func (wh *Webhook) SanitizedUrl(ctx context.Context) string { // Basic auth credentials may have been included in the URL, in which case obscure them - return base.RedactBasicAuthURLUserAndPassword(wh.url) + return base.RedactBasicAuthURLUserAndPassword(ctx, wh.url) } diff --git a/db/event_handler_test.go b/db/event_handler_test.go index 1ad96e29ae..c10735e41f 100644 --- a/db/event_handler_test.go +++ b/db/event_handler_test.go @@ -33,16 +33,16 @@ func TestWebhookString(t *testing.T) { func TestSanitizedUrl(t *testing.T) { var wh *Webhook - + ctx := base.TestCtx(t) wh = &Webhook{ url: "https://foo%40bar.baz:my-%24ecret-p%40%25%24w0rd@example.com:8888/bar", } - assert.Equal(t, "https://xxxxx:xxxxx@example.com:8888/bar", wh.SanitizedUrl()) + assert.Equal(t, "https://xxxxx:xxxxx@example.com:8888/bar", wh.SanitizedUrl(ctx)) wh = &Webhook{ url: "https://example.com/does-not-count-as-url-embedded:basic-auth-credentials@qux", } - assert.Equal(t, "https://example.com/does-not-count-as-url-embedded:basic-auth-credentials@qux", wh.SanitizedUrl()) + assert.Equal(t, "https://example.com/does-not-count-as-url-embedded:basic-auth-credentials@qux", wh.SanitizedUrl(ctx)) } func TestCallValidateFunction(t *testing.T) { @@ -53,30 +53,31 @@ func TestCallValidateFunction(t *testing.T) { bodyBytes, _ := base.JSONMarshal(body) event := &DocumentChangeEvent{DocID: docId, DocBytes: bodyBytes, OldDoc: oldBodyJSON, Channels: channels} + ctx := base.TestCtx(t) // Boolean return type handling of CallValidateFunction; bool true value. source := `function(doc) { if (doc.key1 == "value1") { return true; } else { return false; } }` - filterFunc := NewJSEventFunction(source) + filterFunc := NewJSEventFunction(ctx, source) result, err := filterFunc.CallValidateFunction(event) assert.True(t, result, "It should return true since doc.key1 is value1") assert.NoError(t, err, "It should return boolean result") // Boolean return type handling of CallValidateFunction; bool false value. source = `function(doc) { if (doc.key1 == "value2") { return true; } else { return false; } }` - filterFunc = NewJSEventFunction(source) + filterFunc = NewJSEventFunction(ctx, source) result, err = filterFunc.CallValidateFunction(event) assert.False(t, result, "It should return false since doc.key1 is not value2") assert.NoError(t, err, "It should return boolean result") // Parsable boolean string return type handling of CallValidateFunction. source = `function(doc) { if (doc.key1 == "value1") { return "true"; } else { return "false"; } }` - filterFunc = NewJSEventFunction(source) + filterFunc = NewJSEventFunction(ctx, source) result, err = filterFunc.CallValidateFunction(event) assert.True(t, result, "It should return true since doc.key1 is value1") assert.NoError(t, err, "It should return parsable boolean result") // Non parsable boolean string return type handling of CallValidateFunction. source = `function(doc) { if (doc.key1 == "value1") { return "TrUe"; } else { return "false"; } }` - filterFunc = NewJSEventFunction(source) + filterFunc = NewJSEventFunction(ctx, source) result, err = filterFunc.CallValidateFunction(event) assert.False(t, result, "It should return false since 'TrUe' is non parsable boolean string") assert.Error(t, err, "It should return parsable throw ParseBool error") @@ -84,7 +85,7 @@ func TestCallValidateFunction(t *testing.T) { // Not boolean and not parsable boolean string return type handling of CallValidateFunction. source = `function(doc) { if (doc.key1 == "Pi") { return 3.14; } else { return 0.0; } }` - filterFunc = NewJSEventFunction(source) + filterFunc = NewJSEventFunction(ctx, source) result, err = filterFunc.CallValidateFunction(event) assert.False(t, result, "It should return not boolean and not parsable boolean string value") assert.Error(t, err, "It should throw Validate function returned non-boolean value error") @@ -92,7 +93,7 @@ func TestCallValidateFunction(t *testing.T) { // Simulate CallFunction failure by making syntax error in filter function. source = `function(doc) { invalidKeyword if (doc.key1 == "value1") { return true; } else { return false; } }` - filterFunc = NewJSEventFunction(source) + filterFunc = NewJSEventFunction(ctx, source) result, err = filterFunc.CallValidateFunction(event) assert.False(t, result, "It should return false due to the syntax error in filter function") assert.Error(t, err, "It should throw an error due to syntax error") diff --git a/db/event_manager.go b/db/event_manager.go index a388afd5fd..6c3b137a29 100644 --- a/db/event_manager.go +++ b/db/event_manager.go @@ -64,7 +64,7 @@ func NewEventManager(terminator chan bool) *EventManager { } // Starts the listener queue for the event manager -func (em *EventManager) Start(maxProcesses uint, waitTime int) { +func (em *EventManager) Start(ctx context.Context, maxProcesses uint, waitTime int) { if maxProcesses == 0 { maxProcesses = kMaxActiveEvents @@ -75,7 +75,7 @@ func (em *EventManager) Start(maxProcesses uint, waitTime int) { em.waitTime = waitTime } - base.InfofCtx(context.TODO(), base.KeyEvents, "Starting event manager with max processes:%d, wait time:%d ms", maxProcesses, em.waitTime) + base.InfofCtx(ctx, base.KeyEvents, "Starting event manager with max processes:%d, wait time:%d ms", maxProcesses, em.waitTime) // activeCountChannel limits the number of concurrent events being processed em.activeCountChannel = make(chan bool, maxProcesses) @@ -93,7 +93,7 @@ func (em *EventManager) Start(maxProcesses uint, waitTime int) { return case event := <-em.asyncEventChannel: em.activeCountChannel <- true - go em.ProcessEvent(event) + go em.ProcessEvent(ctx, event) } } }() @@ -101,25 +101,24 @@ func (em *EventManager) Start(maxProcesses uint, waitTime int) { } // Concurrent processing of all async event handlers registered for the event type -func (em *EventManager) ProcessEvent(event Event) { +func (em *EventManager) ProcessEvent(ctx context.Context, event Event) { defer func() { <-em.activeCountChannel }() - logCtx := context.TODO() // Send event to all registered handlers concurrently. WaitGroup blocks // until all are finished var wg sync.WaitGroup for _, handler := range em.eventHandlers[event.EventType()] { - base.DebugfCtx(logCtx, base.KeyEvents, "Event queue worker sending event %s to: %s", base.UD(event.String()), handler) + base.DebugfCtx(ctx, base.KeyEvents, "Event queue worker sending event %s to: %s", base.UD(event.String()), handler) wg.Add(1) go func(event Event, handler EventHandler) { defer wg.Done() //TODO: Currently we're not tracking success/fail from event handlers. When this // is needed, could pass a channel to HandleEvent for tracking results - if handler.HandleEvent(event) { + if handler.HandleEvent(ctx, event) { em.IncrementEventsProcessedSuccess(1) } else { em.IncrementEventsProcessedFail(1) } - base.TracefCtx(logCtx, base.KeyAll, "Webhook event processed %s", event) + base.TracefCtx(ctx, base.KeyAll, "Webhook event processed %s", event) }(event, handler) } @@ -128,10 +127,10 @@ func (em *EventManager) ProcessEvent(event Event) { // Register a new event handler to the EventManager. The event manager will route events of // type eventType to the handler. -func (em *EventManager) RegisterEventHandler(handler EventHandler, eventType EventType) { +func (em *EventManager) RegisterEventHandler(ctx context.Context, handler EventHandler, eventType EventType) { em.eventHandlers[eventType] = append(em.eventHandlers[eventType], handler) em.activeEventTypes[eventType] = true - base.InfofCtx(context.Background(), base.KeyEvents, "Registered event handler: %v, for event type %v", handler, eventType) + base.InfofCtx(ctx, base.KeyEvents, "Registered event handler: %v, for event type %v", handler, eventType) } // Checks whether a handler of the given type has been registered to the event manager. @@ -140,7 +139,7 @@ func (em *EventManager) HasHandlerForEvent(eventType EventType) bool { } // Adds async events to the channel for processing -func (em *EventManager) raiseEvent(event Event) error { +func (em *EventManager) raiseEvent(ctx context.Context, event Event) error { if !event.Synchronous() { // When asyncEventChannel is full, the raiseEvent method will block for (waitTime). // Default value of (waitTime) is 5 ms. @@ -148,10 +147,10 @@ func (em *EventManager) raiseEvent(event Event) error { defer timer.Stop() select { case em.asyncEventChannel <- event: - base.TracefCtx(context.TODO(), base.KeyAll, "Event sent to channel %s", event.String()) + base.TracefCtx(ctx, base.KeyAll, "Event sent to channel %s", event.String()) case <-timer.C: // Event queue channel is full - ignore event and log error - base.WarnfCtx(context.TODO(), "Event queue full - discarding event: %s", base.UD(event.String())) + base.WarnfCtx(ctx, "Event queue full - discarding event: %s", base.UD(event.String())) return errors.New("Event queue full") } } @@ -161,7 +160,7 @@ func (em *EventManager) raiseEvent(event Event) error { // Raises a document change event based on the the document body and channel set. If the // event manager doesn't have a listener for this event, ignores. -func (em *EventManager) RaiseDocumentChangeEvent(docBytes []byte, docID string, oldBodyJSON string, channels base.Set, winningRevChange bool) error { +func (em *EventManager) RaiseDocumentChangeEvent(ctx context.Context, docBytes []byte, docID string, oldBodyJSON string, channels base.Set, winningRevChange bool) error { if !em.activeEventTypes[DocumentChange] { return nil @@ -174,13 +173,13 @@ func (em *EventManager) RaiseDocumentChangeEvent(docBytes []byte, docID string, WinningRevChange: winningRevChange, } - return em.raiseEvent(event) + return em.raiseEvent(ctx, event) } // Raises a DB state change event based on the db name, admininterface, new state, reason and local system time. // If the event manager doesn't have a listener for this event, ignores. -func (em *EventManager) RaiseDBStateChangeEvent(dbName string, state string, reason string, adminInterface *string) error { +func (em *EventManager) RaiseDBStateChangeEvent(ctx context.Context, dbName string, state string, reason string, adminInterface *string) error { if !em.activeEventTypes[DBStateChange] { return nil @@ -202,5 +201,5 @@ func (em *EventManager) RaiseDBStateChangeEvent(dbName string, state string, rea Doc: body, } - return em.raiseEvent(event) + return em.raiseEvent(ctx, event) } diff --git a/db/event_manager_test.go b/db/event_manager_test.go index 4ec09bb134..40b5865822 100644 --- a/db/event_manager_test.go +++ b/db/event_manager_test.go @@ -40,7 +40,7 @@ type TestingHandler struct { t *testing.T // enclosing test instance } -func (th *TestingHandler) HandleEvent(event Event) bool { +func (th *TestingHandler) HandleEvent(_ context.Context, event Event) bool { if th.handleDelay > 0 { time.Sleep(time.Duration(th.handleDelay) * time.Millisecond) @@ -82,11 +82,12 @@ func (th *TestingHandler) String() string { } func TestDocumentChangeEvent(t *testing.T) { + ctx := base.TestCtx(t) terminator := make(chan bool) defer close(terminator) em := NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) // Setup test data ids := make([]string, 20) @@ -110,12 +111,12 @@ func TestDocumentChangeEvent(t *testing.T) { // Setup test handler testHandler := &TestingHandler{HandledEvent: DocumentChange} testHandler.SetChannel(resultChannel) - em.RegisterEventHandler(testHandler, DocumentChange) + em.RegisterEventHandler(ctx, testHandler, DocumentChange) // Raise events for i := 0; i < 10; i++ { body, docid, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) assert.NoError(t, err) } @@ -124,11 +125,12 @@ func TestDocumentChangeEvent(t *testing.T) { } func TestDBStateChangeEvent(t *testing.T) { + ctx := base.TestCtx(t) terminator := make(chan bool) defer close(terminator) em := NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) // Setup test data ids := make([]string, 20) @@ -140,15 +142,15 @@ func TestDBStateChangeEvent(t *testing.T) { // Setup test handler testHandler := &TestingHandler{HandledEvent: DBStateChange, t: t} testHandler.SetChannel(resultChannel) - em.RegisterEventHandler(testHandler, DBStateChange) + em.RegisterEventHandler(ctx, testHandler, DBStateChange) // Raise online events for i := 0; i < 10; i++ { - err := em.RaiseDBStateChangeEvent(ids[i], "online", "DB started from config", base.StringPtr("0.0.0.0:0000")) + err := em.RaiseDBStateChangeEvent(ctx, ids[i], "online", "DB started from config", base.StringPtr("0.0.0.0:0000")) assert.NoError(t, err) } // Raise offline events for i := 10; i < 20; i++ { - err := em.RaiseDBStateChangeEvent(ids[i], "offline", "Sync Gateway context closed", base.StringPtr("0.0.0.0:0000")) + err := em.RaiseDBStateChangeEvent(ctx, ids[i], "offline", "Sync Gateway context closed", base.StringPtr("0.0.0.0:0000")) assert.NoError(t, err) } @@ -165,12 +167,13 @@ func TestDBStateChangeEvent(t *testing.T) { // Test sending many events with slow-running execution to validate they get dropped after hitting // the max concurrent goroutines func TestSlowExecutionProcessing(t *testing.T) { + ctx := base.TestCtx(t) terminator := make(chan bool) defer close(terminator) base.SetUpTestLogging(t, base.LevelInfo, base.KeyEvents) em := NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) ids := make([]string, 20) for i := 0; i < 20; i++ { @@ -194,12 +197,12 @@ func TestSlowExecutionProcessing(t *testing.T) { resultChannel := make(chan interface{}, 100) testHandler := &TestingHandler{HandledEvent: DocumentChange, handleDelay: 500} testHandler.SetChannel(resultChannel) - em.RegisterEventHandler(testHandler, DocumentChange) + em.RegisterEventHandler(ctx, testHandler, DocumentChange) for i := 0; i < 20; i++ { body, docid, channels := eventForTest(i % 10) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) assert.NoError(t, err) } @@ -208,11 +211,12 @@ func TestSlowExecutionProcessing(t *testing.T) { } func TestCustomHandler(t *testing.T) { + ctx := base.TestCtx(t) terminator := make(chan bool) defer close(terminator) em := NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) ids := make([]string, 20) for i := 0; i < 20; i++ { @@ -237,12 +241,12 @@ func TestCustomHandler(t *testing.T) { testHandler := &TestingHandler{HandledEvent: DocumentChange} testHandler.SetChannel(resultChannel) - em.RegisterEventHandler(testHandler, DocumentChange) + em.RegisterEventHandler(ctx, testHandler, DocumentChange) for i := 0; i < 10; i++ { body, docid, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) assert.NoError(t, err) } @@ -251,11 +255,12 @@ func TestCustomHandler(t *testing.T) { } func TestUnhandledEvent(t *testing.T) { + ctx := base.TestCtx(t) terminator := make(chan bool) defer close(terminator) em := NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) ids := make([]string, 20) for i := 0; i < 20; i++ { @@ -281,13 +286,13 @@ func TestUnhandledEvent(t *testing.T) { // create handler for an unhandled event testHandler := &TestingHandler{HandledEvent: math.MaxUint8} testHandler.SetChannel(resultChannel) - em.RegisterEventHandler(testHandler, math.MaxUint8) + em.RegisterEventHandler(ctx, testHandler, math.MaxUint8) // send DocumentChange events to handler for i := 0; i < 10; i++ { body, docid, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) assert.NoError(t, err) } @@ -371,7 +376,7 @@ func (em *EventManager) waitForProcessedTotal(ctx context.Context, waitCount int ctx, cancel := context.WithDeadline(ctx, startTime.Add(maxWaitTime)) sleeper := base.SleeperFuncCtx(base.CreateMaxDoublingSleeperFunc(math.MaxInt64, 1, 1000), ctx) - err, _ := base.RetryLoop(fmt.Sprintf("waitForProcessedTotal(%d)", waitCount), worker, sleeper) + err, _ := base.RetryLoop(ctx, fmt.Sprintf("waitForProcessedTotal(%d)", waitCount), worker, sleeper) cancel() return err } @@ -463,16 +468,17 @@ func TestWebhookBasic(t *testing.T) { // Test basic webhook log.Println("Test basic webhook") em := NewEventManager(terminator) - em.Start(0, -1) - webhookHandler, _ := NewWebhook(fmt.Sprintf("%s/echo", url), "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + ctx := base.TestCtx(t) + em.Start(ctx, 0, -1) + webhookHandler, _ := NewWebhook(ctx, fmt.Sprintf("%s/echo", url), "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docId, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } - err := em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err := em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(10), em.GetEventsProcessedSuccess()) @@ -480,7 +486,7 @@ func TestWebhookBasic(t *testing.T) { log.Println("Test filter function") wr.Clear() em = NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) filterFunction := `function(doc) { if (doc.value < 6) { return false; @@ -488,16 +494,16 @@ func TestWebhookBasic(t *testing.T) { return true; } }` - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docId, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } - err = em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err = em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(4), em.GetEventsProcessedSuccess()) @@ -505,14 +511,14 @@ func TestWebhookBasic(t *testing.T) { log.Println("Test payload validation") wr.Clear() em = NewEventManager(terminator) - em.Start(0, -1) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/echo", url), "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + em.Start(ctx, 0, -1) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/echo", url), "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) body, docId, channels := eventForTest(0) bodyBytes, _ := base.JSONMarshalCanonical(body) - err = em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err = em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) - err = em.waitForProcessedTotal(base.TestCtx(t), 1, DefaultWaitForWebhook) + err = em.waitForProcessedTotal(ctx, 1, DefaultWaitForWebhook) assert.NoError(t, err) receivedPayload := string((wr.GetPayloads())[0]) fmt.Println("payload:", receivedPayload) @@ -555,17 +561,18 @@ func TestWebhookOverflows(t *testing.T) { log.Println("Test fast fill, fast webhook") wr.Clear() em := NewEventManager(terminator) - em.Start(5, -1) + ctx := base.TestCtx(t) + em.Start(ctx, 5, -1) timeout := uint64(60) - webhookHandler, _ := NewWebhook(fmt.Sprintf("%s/echo", url), "", &timeout, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ := NewWebhook(ctx, fmt.Sprintf("%s/echo", url), "", &timeout, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 100; i++ { body, docId, channels := eventForTest(i % 10) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } - err := em.waitForProcessedTotal(base.TestCtx(t), 100, DefaultWaitForWebhook) + err := em.waitForProcessedTotal(ctx, 100, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(100), em.GetEventsProcessedSuccess()) @@ -576,20 +583,20 @@ func TestWebhookOverflows(t *testing.T) { wr.Clear() errCount := 0 em = NewEventManager(terminator) - em.Start(5, 1) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/slow", url), "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + em.Start(ctx, 5, 1) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/slow", url), "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 100; i++ { body, docId, channels := eventForTest(i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) if err != nil { errCount++ } } // Expect 21 to complete. 5 get goroutines immediately, 15 get queued, and one is blocked waiting // for a goroutine. The rest get discarded because the queue is full. - err = em.waitForProcessedTotal(base.TestCtx(t), 21, 10*time.Second) + err = em.waitForProcessedTotal(ctx, 21, 10*time.Second) assert.NoError(t, err) assert.Equal(t, int64(21), em.GetEventsProcessedSuccess()) assert.Equal(t, 79, errCount) @@ -599,16 +606,16 @@ func TestWebhookOverflows(t *testing.T) { log.Println("Test queue full, slow webhook, long wait") wr.Clear() em = NewEventManager(terminator) - em.Start(5, 1500) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/slow", url), "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + em.Start(ctx, 5, 1500) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/slow", url), "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 100; i++ { body, docId, channels := eventForTest(i % 10) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } - err = em.waitForProcessedTotal(base.TestCtx(t), 100, 10*time.Second) + err = em.waitForProcessedTotal(ctx, 100, 10*time.Second) assert.NoError(t, err) assert.Equal(t, int64(100), em.GetEventsProcessedSuccess()) } @@ -645,20 +652,21 @@ func TestWebhookOldDoc(t *testing.T) { // Test basic webhook where an old doc is passed but not filtered log.Println("Test basic webhook where an old doc is passed but not filtered") em := NewEventManager(terminator) - em.Start(0, -1) - webhookHandler, _ := NewWebhook(fmt.Sprintf("%s/echo", url), "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + ctx := base.TestCtx(t) + em.Start(ctx, 0, -1) + webhookHandler, _ := NewWebhook(ctx, fmt.Sprintf("%s/echo", url), "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { oldBody, oldDocId, _ := eventForTest(strconv.Itoa(-i), i) oldBody[BodyId] = oldDocId oldBodyBytes, _ := base.JSONMarshal(oldBody) body, docId, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, string(oldBodyBytes), channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, string(oldBodyBytes), channels, false) assert.NoError(t, err) } - err := em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err := em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(10), em.eventsProcessedSuccess) log.Printf("Actual: %v, Expected: %v", wr.GetCount(), 10) @@ -667,7 +675,7 @@ func TestWebhookOldDoc(t *testing.T) { log.Println("Test filter function with old doc which is not referenced") wr.Clear() em = NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) filterFunction := `function(doc) { if (doc.value < 6) { return false; @@ -675,18 +683,18 @@ func TestWebhookOldDoc(t *testing.T) { return true; } }` - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { oldBody, oldDocId, _ := eventForTest(strconv.Itoa(-i), i) oldBody[BodyId] = oldDocId oldBodyBytes, _ := base.JSONMarshal(oldBody) body, docId, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, string(oldBodyBytes), channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, string(oldBodyBytes), channels, false) assert.NoError(t, err) } - err = em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err = em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(4), em.eventsProcessedSuccess) log.Printf("Actual: %v, Expected: %v", wr.GetCount(), 4) @@ -695,7 +703,7 @@ func TestWebhookOldDoc(t *testing.T) { log.Println("Test filter function with old doc") wr.Clear() em = NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) filterFunction = `function(doc, oldDoc) { if (doc.value < 6 && doc.value == -oldDoc.value) { return false; @@ -703,18 +711,18 @@ func TestWebhookOldDoc(t *testing.T) { return true; } }` - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { oldBody, oldDocId, _ := eventForTest(strconv.Itoa(-i), i) oldBody[BodyId] = oldDocId oldBodyBytes, _ := base.JSONMarshal(oldBody) body, docId, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, string(oldBodyBytes), channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, string(oldBodyBytes), channels, false) assert.NoError(t, err) } - err = em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err = em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(4), em.eventsProcessedSuccess) log.Printf("Actual: %v, Expected: %v", wr.GetCount(), 4) @@ -723,7 +731,7 @@ func TestWebhookOldDoc(t *testing.T) { log.Println("Test filter function with old doc") wr.Clear() em = NewEventManager(terminator) - em.Start(0, -1) + em.Start(ctx, 0, -1) filterFunction = `function(doc, oldDoc) { if (oldDoc) { return true; @@ -731,12 +739,12 @@ func TestWebhookOldDoc(t *testing.T) { return false; } }` - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/echo", url), filterFunction, nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docId, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } for i := 10; i < 20; i++ { @@ -745,10 +753,10 @@ func TestWebhookOldDoc(t *testing.T) { oldBodyBytes, _ := base.JSONMarshal(oldBody) body, docId, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, string(oldBodyBytes), channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, string(oldBodyBytes), channels, false) assert.NoError(t, err) } - err = em.waitForProcessedTotal(base.TestCtx(t), 20, DefaultWaitForWebhook) + err = em.waitForProcessedTotal(ctx, 20, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(10), em.eventsProcessedSuccess) log.Printf("Actual: %v, Expected: %v", wr.GetCount(), 10) @@ -788,17 +796,18 @@ func TestWebhookTimeout(t *testing.T) { // Test fast execution, short timeout. All events processed log.Println("Test fast webhook, short timeout") em := NewEventManager(terminator) - em.Start(0, -1) + ctx := base.TestCtx(t) + em.Start(ctx, 0, -1) timeout := uint64(2) - webhookHandler, _ := NewWebhook(fmt.Sprintf("%s/echo", url), "", &timeout, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ := NewWebhook(ctx, fmt.Sprintf("%s/echo", url), "", &timeout, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docid, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) assert.NoError(t, err) } - err := em.waitForProcessedTotal(base.TestCtx(t), 10, DefaultWaitForWebhook) + err := em.waitForProcessedTotal(ctx, 10, DefaultWaitForWebhook) assert.NoError(t, err) assert.Equal(t, int64(10), em.eventsProcessedSuccess) @@ -810,21 +819,21 @@ func TestWebhookTimeout(t *testing.T) { wr.Clear() errCount := 0 em = NewEventManager(terminator) - em.Start(1, 1500) + em.Start(ctx, 1, 1500) timeout = uint64(1) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/slow_2s", url), "", &timeout, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/slow_2s", url), "", &timeout, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docid, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) time.Sleep(2 * time.Millisecond) if err != nil { errCount++ } } // Even though we timed out waiting for response on the SG side, POST still completed on target side. - err = em.waitForProcessedTotal(base.TestCtx(t), 10, 30*time.Second) + err = em.waitForProcessedTotal(ctx, 10, 30*time.Second) assert.NoError(t, err) assert.Equal(t, int64(0), em.GetEventsProcessedSuccess()) assert.Equal(t, int64(10), em.GetEventsProcessedFail()) @@ -835,21 +844,21 @@ func TestWebhookTimeout(t *testing.T) { wr.Clear() errCount = 0 em = NewEventManager(terminator) - em.Start(1, 100) + em.Start(ctx, 1, 100) timeout = uint64(9) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/slow_5s", url), "", &timeout, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/slow_5s", url), "", &timeout, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docid, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) time.Sleep(2 * time.Millisecond) if err != nil { errCount++ } } // wait for slow webhook to finish processing - err = em.waitForProcessedTotal(base.TestCtx(t), 5, 30*time.Second) + err = em.waitForProcessedTotal(ctx, 5, 30*time.Second) assert.NoError(t, err) assert.Equal(t, int64(5), em.GetEventsProcessedSuccess()) @@ -858,21 +867,21 @@ func TestWebhookTimeout(t *testing.T) { wr.Clear() errCount = 0 em = NewEventManager(terminator) - em.Start(1, 1500) + em.Start(ctx, 1, 1500) timeout = uint64(0) - webhookHandler, _ = NewWebhook(fmt.Sprintf("%s/slow", url), "", &timeout, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + webhookHandler, _ = NewWebhook(ctx, fmt.Sprintf("%s/slow", url), "", &timeout, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docid, channels := eventForTest(strconv.Itoa(i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docid, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docid, "", channels, false) time.Sleep(2 * time.Millisecond) if err != nil { errCount++ } } // wait for slow webhook to finish processing - err = em.waitForProcessedTotal(base.TestCtx(t), 10, 20*time.Second) + err = em.waitForProcessedTotal(ctx, 10, 20*time.Second) assert.NoError(t, err) assert.Equal(t, int64(10), em.eventsProcessedSuccess) @@ -907,13 +916,14 @@ func TestUnavailableWebhook(t *testing.T) { // Test unreachable webhook em := NewEventManager(terminator) - em.Start(0, -1) - webhookHandler, _ := NewWebhook("http://badhost:1000/echo", "", nil, nil) - em.RegisterEventHandler(webhookHandler, DocumentChange) + ctx := base.TestCtx(t) + em.Start(ctx, 0, -1) + webhookHandler, _ := NewWebhook(ctx, "http://badhost:1000/echo", "", nil, nil) + em.RegisterEventHandler(ctx, webhookHandler, DocumentChange) for i := 0; i < 10; i++ { body, docId, channels := eventForTest(strconv.Itoa(-i), i) bodyBytes, _ := base.JSONMarshal(body) - err := em.RaiseDocumentChangeEvent(bodyBytes, docId, "", channels, false) + err := em.RaiseDocumentChangeEvent(ctx, bodyBytes, docId, "", channels, false) assert.NoError(t, err) } time.Sleep(50 * time.Millisecond) @@ -972,7 +982,7 @@ func TestWebhookHandleUnsupportedEventType(t *testing.T) { defer ts.Close() wh := &Webhook{url: ts.URL} event := &UnsupportedEvent{} - success := wh.HandleEvent(event) + success := wh.HandleEvent(base.TestCtx(t), event) assert.False(t, success, "Event shouldn't get posted to webhook; event type is not supported") } @@ -983,8 +993,9 @@ func TestWebhookHandleEventDBStateChangeFilterFuncError(t *testing.T) { wh := &Webhook{url: ts.URL} event := mockDBStateChangeEvent("db", "online", "Index service is listening", "127.0.0.1:4985") source := `function (doc) { invalidKeyword if (doc.state == "online") { return true; } else { return false; } }` - wh.filter = NewJSEventFunction(source) - success := wh.HandleEvent(event) + ctx := base.TestCtx(t) + wh.filter = NewJSEventFunction(ctx, source) + success := wh.HandleEvent(ctx, event) assert.False(t, success, "Filter function processing should be aborted and warnings should be logged") } @@ -996,6 +1007,6 @@ func TestWebhookHandleEventDBStateChangeMarshalDocError(t *testing.T) { body := make(Body, 1) body["key"] = make(chan int) event := &DBStateChangeEvent{Doc: body} - success := wh.HandleEvent(event) + success := wh.HandleEvent(base.TestCtx(t), event) assert.False(t, success, "It should throw marshalling doc error and log warnings") } diff --git a/db/functions/function.go b/db/functions/function.go index dd7d752ec5..875b93f2ef 100644 --- a/db/functions/function.go +++ b/db/functions/function.go @@ -69,7 +69,7 @@ type functionImpl struct { } // Compiles the functions in a UserFunctionConfigMap, returning UserFunctions. -func CompileFunctions(config FunctionsConfig) (*db.UserFunctions, error) { +func CompileFunctions(ctx context.Context, config FunctionsConfig) (*db.UserFunctions, error) { if config.MaxFunctionCount != nil && len(config.Definitions) > *config.MaxFunctionCount { return nil, fmt.Errorf("too many functions declared (> %d)", *config.MaxFunctionCount) } @@ -81,7 +81,7 @@ func CompileFunctions(config FunctionsConfig) (*db.UserFunctions, error) { for name, fnConfig := range config.Definitions { if config.MaxCodeSize != nil && len(fnConfig.Code) > *config.MaxCodeSize { multiError = multiError.Append(fmt.Errorf("function code too large (> %d bytes)", *config.MaxCodeSize)) - } else if userFn, err := compileFunction(name, "function", fnConfig); err == nil { + } else if userFn, err := compileFunction(ctx, name, "function", fnConfig); err == nil { fns.Definitions[name] = userFn } else { multiError = multiError.Append(err) @@ -92,12 +92,12 @@ func CompileFunctions(config FunctionsConfig) (*db.UserFunctions, error) { // Validates a FunctionsConfig. func ValidateFunctions(ctx context.Context, config FunctionsConfig) error { - _, err := CompileFunctions(config) + _, err := CompileFunctions(ctx, config) return err } // Creates a functionImpl from a UserFunctionConfig. -func compileFunction(name string, typeName string, fnConfig *FunctionConfig) (*functionImpl, error) { +func compileFunction(ctx context.Context, name string, typeName string, fnConfig *FunctionConfig) (*functionImpl, error) { userFn := &functionImpl{ FunctionConfig: fnConfig, name: name, @@ -107,7 +107,7 @@ func compileFunction(name string, typeName string, fnConfig *FunctionConfig) (*f var err error switch fnConfig.Type { case "javascript": - userFn.compiled, err = newFunctionJSServer(name, typeName, fnConfig.Code) + userFn.compiled, err = newFunctionJSServer(ctx, name, typeName, fnConfig.Code) case "query": err = validateN1QLQuery(fnConfig.Code) if err != nil { diff --git a/db/functions/function_test.go b/db/functions/function_test.go index 1d36e48e9c..420e5bf844 100644 --- a/db/functions/function_test.go +++ b/db/functions/function_test.go @@ -435,7 +435,7 @@ func TestUserFunctionSyntaxError(t *testing.T) { }, } - _, err := CompileFunctions(kUserFunctionBadConfig) + _, err := CompileFunctions(base.TestCtx(t), kUserFunctionBadConfig) assert.Error(t, err) } @@ -456,7 +456,7 @@ func TestUserFunctionsMaxFunctionCount(t *testing.T) { }, }, } - _, err := CompileFunctions(twoFunctionConfig) + _, err := CompileFunctions(base.TestCtx(t), twoFunctionConfig) assert.ErrorContains(t, err, "too many functions declared (> 1)") } @@ -472,7 +472,7 @@ func TestUserFunctionsMaxCodeSize(t *testing.T) { }, }, } - _, err := CompileFunctions(functionConfig) + _, err := CompileFunctions(base.TestCtx(t), functionConfig) assert.ErrorContains(t, err, "function code too large (> 20 bytes)") } @@ -485,7 +485,7 @@ func TestUserFunctionAllow(t *testing.T) { db, ctx := setupTestDBWithFunctions(t, &kUserFunctionConfig, nil) defer db.Close(ctx) - authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(db.MetadataStore, db, db.AuthenticatorOptions(ctx)) user, err := authenticator.NewUser("maurice", "pass", base.SetOf("city-Paris")) assert.NoError(t, err) @@ -650,11 +650,11 @@ func setupTestDBWithFunctions(t *testing.T, fnConfig *FunctionsConfig, gqConfig } var err error if fnConfig != nil { - options.UserFunctions, err = CompileFunctions(*fnConfig) + options.UserFunctions, err = CompileFunctions(base.TestCtx(t), *fnConfig) assert.NoError(t, err) } if gqConfig != nil { - options.GraphQL, err = CompileGraphQL(gqConfig) + options.GraphQL, err = CompileGraphQL(base.TestCtx(t), gqConfig) assert.NoError(t, err) } return setupTestDBWithOptions(t, options) diff --git a/db/functions/graphql.go b/db/functions/graphql.go index 8d04f35ee3..c1154e66c2 100644 --- a/db/functions/graphql.go +++ b/db/functions/graphql.go @@ -92,8 +92,8 @@ type graphQLImpl struct { } // Creates a new GraphQL instance from its configuration. -func CompileGraphQL(config *GraphQLConfig) (*graphQLImpl, error) { - if schema, err := config.compileSchema(); err != nil { +func CompileGraphQL(ctx context.Context, config *GraphQLConfig) (*graphQLImpl, error) { + if schema, err := config.compileSchema(ctx); err != nil { return nil, err } else { gql := &graphQLImpl{ @@ -106,11 +106,11 @@ func CompileGraphQL(config *GraphQLConfig) (*graphQLImpl, error) { // Validates a GraphQL configuration by parsing the schema. func (config *GraphQLConfig) Validate(ctx context.Context) error { - _, err := config.compileSchema() + _, err := config.compileSchema(ctx) return err } -func (config *GraphQLConfig) compileSchema() (schema graphql.Schema, err error) { +func (config *GraphQLConfig) compileSchema(ctx context.Context) (schema graphql.Schema, err error) { // Get the schema source, from either `schema` or `schemaFile`: schemaSource, err := config.getSchema() if err != nil { @@ -137,11 +137,11 @@ ResolverLoop: } else if fieldName == "__typename" { // The "__typename" resolver returns the name of the concrete type of an // instance of an interface. - typeNameResolver, err = config.compileTypeNameResolver(typeName, fnConfig) + typeNameResolver, err = config.compileTypeNameResolver(ctx, typeName, fnConfig) resolverCount += 1 } else { var fn graphql.FieldResolveFn - fn, err = config.compileFieldResolver(typeName, fieldName, fnConfig) + fn, err = config.compileFieldResolver(ctx, typeName, fieldName, fnConfig) fieldMap[fieldName] = &gqltools.FieldResolve{Resolve: fn} resolverCount += 1 } @@ -177,7 +177,7 @@ ResolverLoop: }) if err == nil && len(schema.TypeMap()) == 0 { - base.WarnfCtx(context.Background(), "GraphQL Schema object has no registered TypeMap -- this probably means the schema has unresolved types. See gqltools warnings above") + base.WarnfCtx(ctx, "GraphQL Schema object has no registered TypeMap -- this probably means the schema has unresolved types. See gqltools warnings above") err = fmt.Errorf("GraphQL Schema object has no registered TypeMap -- this probably means the schema has unresolved types") } return schema, err @@ -219,14 +219,14 @@ func graphQLResolverName(typeName string, fieldName string) string { // Creates a graphQLResolver for the given JavaScript code, and returns a graphql-go FieldResolveFn // that invokes it. -func (config *GraphQLConfig) compileFieldResolver(typeName string, fieldName string, fnConfig FunctionConfig) (graphql.FieldResolveFn, error) { +func (config *GraphQLConfig) compileFieldResolver(ctx context.Context, typeName string, fieldName string, fnConfig FunctionConfig) (graphql.FieldResolveFn, error) { name := graphQLResolverName(typeName, fieldName) isMutation := typeName == "Mutation" if isMutation && fnConfig.Type == "query" { return nil, fmt.Errorf("GraphQL mutations must be implemented in JavaScript") } - userFn, err := compileFunction(name, "GraphQL resolver", &fnConfig) + userFn, err := compileFunction(ctx, name, "GraphQL resolver", &fnConfig) if err != nil { return nil, err } @@ -272,14 +272,14 @@ func resolverInfo(params graphql.ResolveParams) map[string]any { //////// TYPE-NAME RESOLVER: -func (config *GraphQLConfig) compileTypeNameResolver(interfaceName string, fnConfig FunctionConfig) (graphql.ResolveTypeFn, error) { +func (config *GraphQLConfig) compileTypeNameResolver(ctx context.Context, interfaceName string, fnConfig FunctionConfig) (graphql.ResolveTypeFn, error) { if fnConfig.Type != "javascript" { return nil, fmt.Errorf("a GraphQL '__typename__' resolver must be JavaScript") } else if fnConfig.Allow != nil { return nil, fmt.Errorf("'allow' is not valid in a GraphQL '__typename__' resolver") } - fn, err := compileFunction(interfaceName, "GraphQL type-name resolver", &fnConfig) + fn, err := compileFunction(ctx, interfaceName, "GraphQL type-name resolver", &fnConfig) if err != nil { return nil, err } diff --git a/db/functions/graphql_test.go b/db/functions/graphql_test.go index 9d4129ac56..77a5f865d4 100644 --- a/db/functions/graphql_test.go +++ b/db/functions/graphql_test.go @@ -481,7 +481,7 @@ func TestGraphQLMaxSchemaSize(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "GraphQL schema too large (> 20 bytes)") } @@ -513,7 +513,7 @@ func TestGraphQLMaxResolverCount(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "too many GraphQL resolvers (> 1)") } @@ -530,7 +530,7 @@ func TestArgsInResolverConfig(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `'args' is not valid in a GraphQL resolver config`) } @@ -539,7 +539,7 @@ func TestUnresolvedTypesInSchema(t *testing.T) { Schema: base.StringPtr(`type Query{} type abc{def:kkk}`), Resolvers: nil, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `GraphQL Schema object has no registered TypeMap -- this probably means the schema has unresolved types`) } @@ -556,7 +556,7 @@ func TestInvalidMutationType(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `unrecognized 'type' "cpp"`) }) t.Run("Unrecognized type query", func(t *testing.T) { @@ -571,7 +571,7 @@ func TestInvalidMutationType(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `GraphQL mutations must be implemented in JavaScript`) }) } @@ -588,7 +588,7 @@ func TestCompilationErrorInResolverCode(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `500 Error compiling GraphQL resolver "Query:square"`) } @@ -606,7 +606,7 @@ func TestGraphQLMaxCodeSize(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "resolver square code too large (> 2 bytes)") } @@ -630,7 +630,7 @@ func TestTypenameResolver(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "a GraphQL '__typename__' resolver must be JavaScript") }) t.Run("Error in compiling typename resolver", func(t *testing.T) { @@ -649,7 +649,7 @@ func TestTypenameResolver(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `Error compiling GraphQL type-name resolver "Book"`) }) t.Run("Typename Resolver should not have allow", func(t *testing.T) { @@ -671,7 +671,7 @@ func TestTypenameResolver(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "'allow' is not valid in a GraphQL '__typename__' resolver") }) @@ -699,7 +699,7 @@ func TestTypenameResolver(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.NoError(t, err) db, ctx := setupTestDBWithFunctions(t, nil, &config) defer db.Close(ctx) @@ -716,7 +716,7 @@ func TestInvalidSchemaAndSchemaFile(t *testing.T) { SchemaFile: base.StringPtr("someInvalidPath/someInvalidFileName"), Resolvers: nil, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "GraphQL config: only one of `schema` and `schemaFile` may be used") }) @@ -724,7 +724,7 @@ func TestInvalidSchemaAndSchemaFile(t *testing.T) { var config = GraphQLConfig{ Resolvers: nil, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, "GraphQL config: either `schema` or `schemaFile` must be defined") }) @@ -732,7 +732,7 @@ func TestInvalidSchemaAndSchemaFile(t *testing.T) { var config = GraphQLConfig{ SchemaFile: base.StringPtr("dummySchemaFile.txt"), } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) fmt.Println(err) assert.ErrorContains(t, err, "can't read file") }) @@ -747,7 +747,7 @@ func TestValidSchemaFile(t *testing.T) { var config = GraphQLConfig{ SchemaFile: base.StringPtr("schema.graphql"), } - _, err = CompileGraphQL(&config) + _, err = CompileGraphQL(base.TestCtx(t), &config) assert.NoError(t, err) err = os.Remove("schema.graphql") @@ -774,6 +774,6 @@ func TestFixOfCVE_2022_37315(t *testing.T) { }, }, } - _, err := CompileGraphQL(&config) + _, err := CompileGraphQL(base.TestCtx(t), &config) assert.ErrorContains(t, err, `Syntax Error GraphQL (1:1) Unexpected Name "String"`) } diff --git a/db/functions/js_runner.go b/db/functions/js_runner.go index d0a66a8e5b..d38077ac39 100644 --- a/db/functions/js_runner.go +++ b/db/functions/js_runner.go @@ -42,11 +42,11 @@ var kUserFunctionMaxCallDepth = 20 var kJavaScriptWrapper string // Creates a JSServer instance wrapping a userJSRunner, for user JS functions and GraphQL resolvers. -func newFunctionJSServer(name string, what string, sourceCode string) (*sgbucket.JSServer, error) { +func newFunctionJSServer(ctx context.Context, name string, what string, sourceCode string) (*sgbucket.JSServer, error) { js := fmt.Sprintf(kJavaScriptWrapper, sourceCode) jsServer := sgbucket.NewJSServer(js, 0, kUserFunctionCacheSize, func(fnSource string, timeout time.Duration) (sgbucket.JSServerTask, error) { - return newJSRunner(name, what, fnSource) + return newJSRunner(ctx, name, what, fnSource) }) // Call WithTask to force a task to be instantiated, which will detect syntax errors in the script. Otherwise the error only gets detected the first time a client calls the function. var err error @@ -67,8 +67,7 @@ type jsRunner struct { } // Creates a jsRunner given its name and JavaScript source code. -func newJSRunner(name string, kind string, funcSource string) (*jsRunner, error) { - ctx := context.Background() +func newJSRunner(ctx context.Context, name string, kind string, funcSource string) (*jsRunner, error) { runner := &jsRunner{ name: name, kind: kind, @@ -104,7 +103,7 @@ func newJSRunner(name string, kind string, funcSource string) (*jsRunner, error) runner.currentDB = nil }() if err != nil { - base.ErrorfCtx(context.Background(), base.KeyJavascript.String()+": %s %s failed: %#v", runner.kind, runner.name, err) + base.ErrorfCtx(ctx, base.KeyJavascript.String()+": %s %s failed: %#v", runner.kind, runner.name, err) return nil, runner.convertError(err) } return jsResult.Export() diff --git a/db/functions/main_test.go b/db/functions/main_test.go index 0b82743a8b..c35c222c81 100644 --- a/db/functions/main_test.go +++ b/db/functions/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package functions import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/db/functions/n1ql_function_test.go b/db/functions/n1ql_function_test.go index f75f2dc3a9..5b50323f20 100644 --- a/db/functions/n1ql_function_test.go +++ b/db/functions/n1ql_function_test.go @@ -232,7 +232,7 @@ func TestUserN1QLQueriesInvalid(t *testing.T) { }, } - _, err := CompileFunctions(kBadN1QLFunctionsConfig) + _, err := CompileFunctions(base.TestCtx(t), kBadN1QLFunctionsConfig) assert.ErrorContains(t, err, "only SELECT queries are allowed") // See fn validateN1QLQuery var kOKN1QLFunctionsConfig = FunctionsConfig{ @@ -252,6 +252,6 @@ func TestUserN1QLQueriesInvalid(t *testing.T) { }, } - _, err = CompileFunctions(kOKN1QLFunctionsConfig) + _, err = CompileFunctions(base.TestCtx(t), kOKN1QLFunctionsConfig) assert.NoError(t, err) } diff --git a/db/import.go b/db/import.go index c6a7c84300..a749dd573d 100644 --- a/db/import.go +++ b/db/import.go @@ -91,7 +91,7 @@ func (db *DatabaseCollectionWithUser) ImportDoc(ctx context.Context, docid strin return nil, err } - return db.importDoc(ctx, docid, existingDoc.Body(), expiry, isDelete, existingBucketDoc, mode) + return db.importDoc(ctx, docid, existingDoc.Body(ctx), expiry, isDelete, existingBucketDoc, mode) } // Import document @@ -150,7 +150,7 @@ func (db *DatabaseCollectionWithUser) importDoc(ctx context.Context, docid strin // If this is an on-demand import, we want to continue to import the current version of the doc. Re-initialize existing doc based on the latest doc if mode == ImportOnDemand { - body = doc.Body() + body = doc.Body(ctx) if body == nil { return nil, nil, false, nil, base.ErrEmptyDocument } @@ -172,7 +172,7 @@ func (db *DatabaseCollectionWithUser) importDoc(ctx context.Context, docid strin if doc.inlineSyncData { existingDoc.Body, err = doc.MarshalBodyAndSync() } else { - existingDoc.Body, err = doc.BodyBytes() + existingDoc.Body, err = doc.BodyBytes(ctx) } if err != nil { @@ -196,7 +196,7 @@ func (db *DatabaseCollectionWithUser) importDoc(ctx context.Context, docid strin // If document still requires import post-migration attempt, continue with import processing based on the body returned by migrate doc = migratedDoc - body = migratedDoc.Body() + body = migratedDoc.Body(ctx) base.InfofCtx(ctx, base.KeyMigrate, "Falling back to import with cas: %v", doc.Cas) } @@ -276,7 +276,7 @@ func (db *DatabaseCollectionWithUser) importDoc(ctx context.Context, docid strin if shouldGenerateNewRev { // The active rev is the parent for an import parentRev := doc.CurrentRev - generation, _ := ParseRevID(parentRev) + generation, _ := ParseRevID(ctx, parentRev) generation++ newRev = CreateRevIDWithBytes(generation, parentRev, rawBodyForRevID) if err != nil { @@ -373,7 +373,7 @@ func (db *DatabaseCollectionWithUser) migrateMetadata(ctx context.Context, docid } // Move any large revision bodies to external storage - err = doc.migrateRevisionBodies(db.dataStore) + err = doc.migrateRevisionBodies(ctx, db.dataStore) if err != nil { base.InfofCtx(ctx, base.KeyMigrate, "Error migrating revision bodies to external storage, doc %q, (cas=%d), Error: %v", base.UD(docid), doc.Cas, err) } @@ -476,7 +476,7 @@ type ImportFilterFunction struct { func NewImportFilterFunction(ctx context.Context, fnSource string, timeout time.Duration) *ImportFilterFunction { - base.DebugfCtx(context.Background(), base.KeyImport, "Creating new ImportFilterFunction") + base.DebugfCtx(ctx, base.KeyImport, "Creating new ImportFilterFunction") return &ImportFilterFunction{ JSServer: sgbucket.NewJSServer(fnSource, timeout, kTaskCacheSize, func(fnSource string, timeout time.Duration) (sgbucket.JSServerTask, error) { diff --git a/db/import_test.go b/db/import_test.go index ce02d0369c..70b6a803de 100644 --- a/db/import_test.go +++ b/db/import_test.go @@ -60,7 +60,7 @@ func TestMigrateMetadata(t *testing.T) { assert.NoError(t, err, "Error writing doc w/ expiry") // Get the existing bucket doc - _, existingBucketDoc, err := collection.GetDocWithXattr(key, DocUnmarshalAll) + _, existingBucketDoc, err := collection.GetDocWithXattr(ctx, key, DocUnmarshalAll) require.NoError(t, err) // Set the expiry value to a stale value (it's about to be stale, since below it will get updated to a later value) existingBucketDoc.Expiry = uint32(syncMetaExpiry.Unix()) @@ -153,7 +153,7 @@ func TestImportWithStaleBucketDocCorrectExpiry(t *testing.T) { assert.NoError(t, err, "Error writing doc w/ expiry") // Get the existing bucket doc - _, existingBucketDoc, err := collection.GetDocWithXattr(key, DocUnmarshalAll) + _, existingBucketDoc, err := collection.GetDocWithXattr(ctx, key, DocUnmarshalAll) assert.NoError(t, err, fmt.Sprintf("Error retrieving doc w/ xattr: %v", err)) body = Body{} @@ -322,7 +322,7 @@ func TestImportWithCasFailureUpdate(t *testing.T) { assert.NoError(t, err) // Get the existing bucket doc - _, existingBucketDoc, err = collection.GetDocWithXattr(testcase.docname, DocUnmarshalAll) + _, existingBucketDoc, err = collection.GetDocWithXattr(ctx, testcase.docname, DocUnmarshalAll) assert.NoError(t, err, fmt.Sprintf("Error retrieving doc w/ xattr: %v", err)) importD := `{"new":"Val"}` @@ -429,7 +429,7 @@ func assertXattrSyncMetaRevGeneration(t *testing.T, dataStore base.DataStore, ke assert.NoError(t, err, "Error Getting Xattr") revision, ok := xattr["rev"] assert.True(t, ok) - generation, _ := ParseRevID(revision.(string)) + generation, _ := ParseRevID(base.TestCtx(t), revision.(string)) log.Printf("assertXattrSyncMetaRevGeneration generation: %d rev: %s", generation, revision) assert.True(t, generation == expectedRevGeneration) } @@ -544,15 +544,13 @@ func TestImportNonZeroStart(t *testing.T) { defer db.Close(ctx) collection := GetSingleDatabaseCollectionWithUser(t, db) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return collection.collectionStats.ImportCount.Value() }, 1) - require.True(t, ok) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return db.DbStats.Database().DCPReceivedCount.Value() }, 1) - require.True(t, ok) doc, err := collection.GetDocument(base.TestCtx(t), doc1, DocUnmarshalAll) require.NoError(t, err) @@ -577,9 +575,8 @@ func TestImportInvalidMetadata(t *testing.T) { _, err := bucket.GetSingleDataStore().Add("doc1", 0, `{"foo" : "bar", "_sync" : 1 }`) require.NoError(t, err) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return db.DbStats.SharedBucketImport().ImportErrorCount.Value() }, 1) - require.True(t, ok) require.Equal(t, int64(0), db.DbStats.SharedBucketImport().ImportCount.Value()) } diff --git a/db/indexes.go b/db/indexes.go index bd683e6398..337201b0b6 100644 --- a/db/indexes.go +++ b/db/indexes.go @@ -337,7 +337,7 @@ func (i *SGIndex) createIfNeeded(ctx context.Context, bucket base.N1QLStore, opt } description := fmt.Sprintf("Attempt to create index %s", indexName) - err, _ = base.RetryLoop(description, worker, sleeper) + err, _ = base.RetryLoop(ctx, description, worker, sleeper) if err != nil { return false, pkgerrors.Wrapf(err, "Error installing Couchbase index: %v", indexName) diff --git a/db/indextest/main_test.go b/db/indextest/main_test.go index 0c16d395eb..dde193a570 100644 --- a/db/indextest/main_test.go +++ b/db/indextest/main_test.go @@ -25,8 +25,9 @@ func TestMain(m *testing.M) { if base.UnitTestUrlIsWalrus() || base.TestsDisableGSI() { return } + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - base.TestBucketPoolMain(m, primaryIndexReadier, primaryIndexInit, tbpOptions) + base.TestBucketPoolMain(ctx, m, primaryIndexReadier, primaryIndexInit, tbpOptions) } // primaryIndexInit is run synchronously only once per-bucket to create a primary index. diff --git a/db/main_test.go b/db/main_test.go index 7b764382e2..9f7157e6d0 100644 --- a/db/main_test.go +++ b/db/main_test.go @@ -11,12 +11,14 @@ licenses/APL2.txt. package db import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - TestBucketPoolWithIndexes(m, tbpOptions) + TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/db/repair_bucket.go b/db/repair_bucket.go index 2ca3bf0c0a..2bfae2208c 100644 --- a/db/repair_bucket.go +++ b/db/repair_bucket.go @@ -50,7 +50,7 @@ type RepairBucketResult struct { // Given a Couchbase Bucket doc, transform the doc in some way to produce a new doc. // Also return a boolean to indicate whether a transformation took place, or any errors occurred. -type DocTransformer func(docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, err error) +type DocTransformer func(ctx context.Context, docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, err error) // A RepairBucket struct is the main API entrypoint to call for repairing documents in buckets type RepairBucket struct { @@ -297,12 +297,12 @@ func (r RepairBucket) WriteRepairedDocsToBucket(docId string, originalDoc, updat } // Loops over all repair jobs and applies them -func (r RepairBucket) TransformBucketDoc(docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, repairJobs []RepairJobType, err error) { +func (r RepairBucket) TransformBucketDoc(ctx context.Context, docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, repairJobs []RepairJobType, err error) { transformed = false for _, repairJob := range r.RepairJobs { - repairedDoc, repairedDocTxformed, repairDocErr := repairJob(docId, originalCBDoc) + repairedDoc, repairedDocTxformed, repairDocErr := repairJob(ctx, docId, originalCBDoc) if repairDocErr != nil { return nil, false, repairJobs, repairDocErr } @@ -332,10 +332,10 @@ func (r RepairBucket) TransformBucketDoc(docId string, originalCBDoc []byte) (tr } // Repairs rev tree cycles (see SG issue #2847) -func RepairJobRevTreeCycles(docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, err error) { +func RepairJobRevTreeCycles(ctx context.Context, docId string, originalCBDoc []byte) (transformedCBDoc []byte, transformed bool, err error) { - base.DebugfCtx(context.TODO(), base.KeyCRUD, "RepairJobRevTreeCycles() called with doc id: %v", base.UD(docId)) - defer base.DebugfCtx(context.TODO(), base.KeyCRUD, "RepairJobRevTreeCycles() finished. Doc id: %v. transformed: %v. err: %v", base.UD(docId), base.UD(transformed), err) + base.DebugfCtx(ctx, base.KeyCRUD, "RepairJobRevTreeCycles() called with doc id: %v", base.UD(docId)) + defer base.DebugfCtx(ctx, base.KeyCRUD, "RepairJobRevTreeCycles() finished. Doc id: %v. transformed: %v. err: %v", base.UD(docId), base.UD(transformed), err) doc, errUnmarshal := unmarshalDocument(docId, originalCBDoc) if errUnmarshal != nil { @@ -351,7 +351,7 @@ func RepairJobRevTreeCycles(docId string, originalCBDoc []byte) (transformedCBDo } // Repair it - if err := doc.History.RepairCycles(); err != nil { + if err := doc.History.RepairCycles(ctx); err != nil { return nil, false, err } diff --git a/db/revision.go b/db/revision.go index 12b6dc6bea..90c783f305 100644 --- a/db/revision.go +++ b/db/revision.go @@ -73,16 +73,16 @@ func (b *Body) Unmarshal(data []byte) error { return nil } -func (body Body) Copy(copyType BodyCopyType) Body { +func (body Body) Copy(ctx context.Context, copyType BodyCopyType) Body { switch copyType { case BodyShallowCopy: return body.ShallowCopy() case BodyDeepCopy: - return body.DeepCopy() + return body.DeepCopy(ctx) case BodyNoCopy: return body default: - base.InfofCtx(context.Background(), base.KeyCRUD, "Unexpected copy type specified in body.Copy - defaulting to shallow copy. copyType: %d", copyType) + base.InfofCtx(ctx, base.KeyCRUD, "Unexpected copy type specified in body.Copy - defaulting to shallow copy. copyType: %d", copyType) return body.ShallowCopy() } } @@ -98,11 +98,11 @@ func (body Body) ShallowCopy() Body { return copied } -func (body Body) DeepCopy() Body { +func (body Body) DeepCopy(ctx context.Context) Body { var copiedBody Body err := base.DeepCopyInefficient(&copiedBody, body) if err != nil { - base.InfofCtx(context.Background(), base.KeyCRUD, "Error copying body: %v", err) + base.InfofCtx(ctx, base.KeyCRUD, "Error copying body: %v", err) } return copiedBody } @@ -360,37 +360,37 @@ func CreateRevIDWithBytes(generation int, parentRevID string, bodyBytes []byte) } // Returns the generation number (numeric prefix) of a revision ID. -func genOfRevID(revid string) int { +func genOfRevID(ctx context.Context, revid string) int { if revid == "" { return 0 } var generation int n, _ := fmt.Sscanf(revid, "%d-", &generation) if n < 1 || generation < 1 { - base.WarnfCtx(context.Background(), "genOfRevID unsuccessful for %q", revid) + base.WarnfCtx(ctx, "genOfRevID unsuccessful for %q", revid) return -1 } return generation } // Splits a revision ID into generation number and hex digest. -func ParseRevID(revid string) (int, string) { +func ParseRevID(ctx context.Context, revid string) (int, string) { if revid == "" { return 0, "" } idx := strings.Index(revid, "-") if idx == -1 { - base.WarnfCtx(context.Background(), "parseRevID found no separator in rev %q", revid) + base.WarnfCtx(ctx, "parseRevID found no separator in rev %q", revid) return -1, "" } gen, err := strconv.Atoi(revid[:idx]) if err != nil { - base.WarnfCtx(context.Background(), "parseRevID unexpected generation in rev %q: %s", revid, err) + base.WarnfCtx(ctx, "parseRevID unexpected generation in rev %q: %s", revid, err) return -1, "" } else if gen < 1 { - base.WarnfCtx(context.Background(), "parseRevID unexpected generation in rev %q", revid) + base.WarnfCtx(ctx, "parseRevID unexpected generation in rev %q", revid) return -1, "" } @@ -401,9 +401,9 @@ func ParseRevID(revid string) (int, string) { // 1 if id1 is 'greater' than id2 // -1 if id1 is 'less' than id2 // 0 if the two are equal. -func compareRevIDs(id1, id2 string) int { - gen1, sha1 := ParseRevID(id1) - gen2, sha2 := ParseRevID(id2) +func compareRevIDs(ctx context.Context, id1, id2 string) int { + gen1, sha1 := ParseRevID(ctx, id1) + gen2, sha2 := ParseRevID(ctx, id2) switch { case gen1 > gen2: return 1 diff --git a/db/revision_cache_interface.go b/db/revision_cache_interface.go index abdc2faa9e..cd8ba32b39 100644 --- a/db/revision_cache_interface.go +++ b/db/revision_cache_interface.go @@ -154,7 +154,7 @@ func (rev *DocumentRevision) Body() (b Body, err error) { // Mutable1xBody returns a copy of the given document revision as a 1.x style body (with special properties) // Callers are free to modify this body without affecting the document revision. -func (rev *DocumentRevision) Mutable1xBody(db *DatabaseCollectionWithUser, requestedHistory Revisions, attachmentsSince []string, showExp bool) (b Body, err error) { +func (rev *DocumentRevision) Mutable1xBody(ctx context.Context, db *DatabaseCollectionWithUser, requestedHistory Revisions, attachmentsSince []string, showExp bool) (b Body, err error) { b, err = rev.Body() if err != nil { return nil, err @@ -186,7 +186,7 @@ func (rev *DocumentRevision) Mutable1xBody(db *DatabaseCollectionWithUser, reque if len(attachmentsSince) > 0 { ancestor := rev.History.findAncestor(attachmentsSince) if ancestor != "" { - minRevpos, _ = ParseRevID(ancestor) + minRevpos, _ = ParseRevID(ctx, ancestor) minRevpos++ } } @@ -207,9 +207,9 @@ func (rev *DocumentRevision) Mutable1xBody(db *DatabaseCollectionWithUser, reque } // As1xBytes returns a byte slice representing the 1.x style body, containing special properties (i.e. _id, _rev, _attachments, etc.) -func (rev *DocumentRevision) As1xBytes(db *DatabaseCollectionWithUser, requestedHistory Revisions, attachmentsSince []string, showExp bool) (b []byte, err error) { +func (rev *DocumentRevision) As1xBytes(ctx context.Context, db *DatabaseCollectionWithUser, requestedHistory Revisions, attachmentsSince []string, showExp bool) (b []byte, err error) { // unmarshal - body1x, err := rev.Mutable1xBody(db, requestedHistory, attachmentsSince, showExp) + body1x, err := rev.Mutable1xBody(ctx, db, requestedHistory, attachmentsSince, showExp) if err != nil { return nil, err } @@ -264,7 +264,7 @@ func revCacheLoaderForDocument(ctx context.Context, backingStore RevisionCacheBa if bodyBytes, body, attachments, err = backingStore.getRevision(ctx, doc, revid); err != nil { // If we can't find the revision (either as active or conflicted body from the document, or as old revision body backup), check whether // the revision was a channel removal. If so, we want to store as removal in the revision cache - removalBodyBytes, removalHistory, activeChannels, isRemoval, isDelete, isRemovalErr := doc.IsChannelRemoval(revid) + removalBodyBytes, removalHistory, activeChannels, isRemoval, isDelete, isRemovalErr := doc.IsChannelRemoval(ctx, revid) if isRemovalErr != nil { return bodyBytes, body, history, channels, isRemoval, nil, isDelete, nil, isRemovalErr } @@ -282,7 +282,7 @@ func revCacheLoaderForDocument(ctx context.Context, backingStore RevisionCacheBa if getHistoryErr != nil { return bodyBytes, body, history, channels, removed, nil, deleted, nil, getHistoryErr } - history = encodeRevisions(doc.ID, validatedHistory) + history = encodeRevisions(ctx, doc.ID, validatedHistory) channels = doc.History[revid].Channels return bodyBytes, body, history, channels, removed, attachments, deleted, doc.Expiry, err diff --git a/db/revision_cache_lru.go b/db/revision_cache_lru.go index 9a356bdc17..575c7c6811 100644 --- a/db/revision_cache_lru.go +++ b/db/revision_cache_lru.go @@ -423,7 +423,7 @@ func (value *revCacheValue) loadForDoc(ctx context.Context, backingStore Revisio docRev, err = value.asDocumentRevision(docRevBody, nil) // If the body is requested and not yet populated on revCacheValue, populate it from the doc if includeBody && docRev._shallowCopyBody == nil { - body := doc.Body() + body := doc.Body(ctx) value.lock.Lock() if value.body == nil { value.body = body diff --git a/db/revision_test.go b/db/revision_test.go index fc9faaaff4..683e477a4d 100644 --- a/db/revision_test.go +++ b/db/revision_test.go @@ -25,21 +25,22 @@ func TestParseRevID(t *testing.T) { var generation int var digest string - generation, _ = ParseRevID("ljlkjl") + ctx := base.TestCtx(t) + generation, _ = ParseRevID(ctx, "ljlkjl") log.Printf("generation: %v", generation) assert.True(t, generation == -1, "Expected -1 generation for invalid rev id") - generation, digest = ParseRevID("1-ljlkjl") + generation, digest = ParseRevID(ctx, "1-ljlkjl") log.Printf("generation: %v, digest: %v", generation, digest) assert.True(t, generation == 1, "Expected 1 generation") assert.True(t, digest == "ljlkjl", "Unexpected digest") - generation, digest = ParseRevID("2222-") + generation, digest = ParseRevID(ctx, "2222-") log.Printf("generation: %v, digest: %v", generation, digest) assert.True(t, generation == 2222, "Expected invalid generation") assert.True(t, digest == "", "Unexpected digest") - generation, digest = ParseRevID("333-a") + generation, digest = ParseRevID(ctx, "333-a") log.Printf("generation: %v, digest: %v", generation, digest) assert.True(t, generation == 333, "Expected generation") assert.True(t, digest == "a", "Unexpected digest") @@ -211,7 +212,7 @@ func BenchmarkSpecialProperties(b *testing.B) { "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, } - specialBody := noSpecialBody.Copy(BodyShallowCopy) + specialBody := noSpecialBody.Copy(base.TestCtx(b), BodyShallowCopy) specialBody[BodyId] = "abc123" specialBody[BodyRev] = "1-abc" diff --git a/db/revtree.go b/db/revtree.go index 0e658c7096..08c2dc11bf 100644 --- a/db/revtree.go +++ b/db/revtree.go @@ -187,7 +187,7 @@ func (tree RevTree) ContainsCycles() bool { } // Repair rev trees that have cycles introduced by SG Issue #2847 -func (tree RevTree) RepairCycles() (err error) { +func (tree RevTree) RepairCycles(ctx context.Context) (err error) { // This function will be called back for every leaf node in tree leafProcessor := func(leaf *RevInfo) { @@ -200,8 +200,8 @@ func (tree RevTree) RepairCycles() (err error) { for { - if node.ParentGenGTENodeGen() { - base.InfofCtx(context.Background(), base.KeyCRUD, "Node %+v detected to have invalid parent rev (parent generation larger than node generation). Repairing by designating as a root node.", base.UD(node)) + if node.ParentGenGTENodeGen(ctx) { + base.InfofCtx(ctx, base.KeyCRUD, "Node %+v detected to have invalid parent rev (parent generation larger than node generation). Repairing by designating as a root node.", base.UD(node)) node.Parent = "" break } @@ -231,8 +231,8 @@ func (tree RevTree) RepairCycles() (err error) { // // where the parent generation is *higher* than the node generation, which is never a valid scenario. // Likewise, detect situations where the parent generation is equal to the node generation, which is also invalid. -func (node RevInfo) ParentGenGTENodeGen() bool { - return genOfRevID(node.Parent) >= genOfRevID(node.ID) +func (node RevInfo) ParentGenGTENodeGen(ctx context.Context) bool { + return genOfRevID(ctx, node.Parent) >= genOfRevID(ctx, node.ID) } // Returns true if the RevTree has an entry for this revid. @@ -313,7 +313,7 @@ func (tree RevTree) isLeaf(revid string) bool { // Finds the "winning" revision, the one that should be treated as the default. // This is the leaf revision whose (!deleted, generation, hash) tuple compares the highest. -func (tree RevTree) winningRevision() (winner string, branched bool, inConflict bool) { +func (tree RevTree) winningRevision(ctx context.Context) (winner string, branched bool, inConflict bool) { winnerExists := false leafCount := 0 activeLeafCount := 0 @@ -324,7 +324,7 @@ func (tree RevTree) winningRevision() (winner string, branched bool, inConflict activeLeafCount++ } if (exists && !winnerExists) || - ((exists == winnerExists) && compareRevIDs(info.ID, winner) > 0) { + ((exists == winnerExists) && compareRevIDs(ctx, info.ID, winner) > 0) { winner = info.ID winnerExists = exists } @@ -423,10 +423,10 @@ func (tree RevTree) setRevisionBody(revid string, body []byte, bodyKey string, h info.HasAttachments = hasAttachments } -func (tree RevTree) removeRevisionBody(revid string) (deletedBodyKey string) { +func (tree RevTree) removeRevisionBody(ctx context.Context, revid string) (deletedBodyKey string) { info, found := tree[revid] if !found { - base.ErrorfCtx(context.Background(), "RemoveRevisionBody called for revid not in tree: %v", revid) + base.ErrorfCtx(ctx, "RemoveRevisionBody called for revid not in tree: %v", revid) return "" } deletedBodyKey = info.BodyKey @@ -461,7 +461,7 @@ func (tree RevTree) copy() RevTree { // // pruned: number of revisions pruned // prunedTombstoneBodyKeys: set of tombstones with external body storage that were pruned, as map[revid]bodyKey -func (tree RevTree) pruneRevisions(maxDepth uint32, keepRev string) (pruned int, prunedTombstoneBodyKeys map[string]string) { +func (tree RevTree) pruneRevisions(ctx context.Context, maxDepth uint32, keepRev string) (pruned int, prunedTombstoneBodyKeys map[string]string) { if len(tree) <= int(maxDepth) { return @@ -479,7 +479,7 @@ func (tree RevTree) pruneRevisions(maxDepth uint32, keepRev string) (pruned int, } // Calculate tombstoneGenerationThreshold - genShortestNonTSBranch, foundShortestNonTSBranch := tree.FindShortestNonTombstonedBranch() + genShortestNonTSBranch, foundShortestNonTSBranch := tree.FindShortestNonTombstonedBranch(ctx) tombstoneGenerationThreshold := -1 if foundShortestNonTSBranch { // Only set the tombstoneGenerationThreshold if a genShortestNonTSBranch was found. (fixes #2695) @@ -493,7 +493,7 @@ func (tree RevTree) pruneRevisions(maxDepth uint32, keepRev string) (pruned int, if !leaf.Deleted { // Ignore non-tombstoned leaves continue } - leafGeneration, _ := ParseRevID(leaf.ID) + leafGeneration, _ := ParseRevID(ctx, leaf.ID) if leafGeneration < tombstoneGenerationThreshold { pruned += tree.DeleteBranch(leaf) if leaf.BodyKey != "" { @@ -567,11 +567,11 @@ func (tree RevTree) computeDepthsAndFindLeaves() (maxDepth uint32, leaves []stri // http://cbmobile-bucket.s3.amazonaws.com/diagrams/example-sync-gateway-revtrees/three_branches.png // // The minimim generation that has a non-deleted leaf is "7-non-winning unresolved" -func (tree RevTree) FindShortestNonTombstonedBranch() (generation int, found bool) { - return tree.FindShortestNonTombstonedBranchFromLeaves(tree.GetLeaves()) +func (tree RevTree) FindShortestNonTombstonedBranch(ctx context.Context) (generation int, found bool) { + return tree.FindShortestNonTombstonedBranchFromLeaves(ctx, tree.GetLeaves()) } -func (tree RevTree) FindShortestNonTombstonedBranchFromLeaves(leaves []string) (generation int, found bool) { +func (tree RevTree) FindShortestNonTombstonedBranchFromLeaves(ctx context.Context, leaves []string) (generation int, found bool) { found = false genShortestNonTSBranch := math.MaxInt32 @@ -583,7 +583,7 @@ func (tree RevTree) FindShortestNonTombstonedBranchFromLeaves(leaves []string) ( // This is a tombstoned branch, skip it continue } - gen := genOfRevID(revid) + gen := genOfRevID(ctx, revid) if gen > 0 && gen < genShortestNonTSBranch { genShortestNonTSBranch = gen found = true @@ -597,14 +597,14 @@ func (tree RevTree) FindShortestNonTombstonedBranchFromLeaves(leaves []string) ( // http://cbmobile-bucket.s3.amazonaws.com/diagrams/example-sync-gateway-revtrees/four_branches_two_tombstoned.png // // The longest deleted branch has a generation of 10 -func (tree RevTree) FindLongestTombstonedBranch() (generation int) { - return tree.FindLongestTombstonedBranchFromLeaves(tree.GetLeaves()) +func (tree RevTree) FindLongestTombstonedBranch(ctx context.Context) (generation int) { + return tree.FindLongestTombstonedBranchFromLeaves(ctx, tree.GetLeaves()) } -func (tree RevTree) FindLongestTombstonedBranchFromLeaves(leaves []string) (generation int) { +func (tree RevTree) FindLongestTombstonedBranchFromLeaves(ctx context.Context, leaves []string) (generation int) { genLongestTSBranch := 0 for _, revid := range leaves { - gen := genOfRevID(revid) + gen := genOfRevID(ctx, revid) if tree[revid].Deleted { if gen > genLongestTSBranch { genLongestTSBranch = gen @@ -742,7 +742,7 @@ func (tree RevTree) getHistory(revid string) ([]string, error) { // ////// ENCODED REVISION LISTS (_revisions): // Parses a CouchDB _rev or _revisions property into a list of revision IDs -func ParseRevisions(body Body) []string { +func ParseRevisions(ctx context.Context, body Body) []string { // http://wiki.apache.org/couchdb/HTTP_Document_API#GET revisionsProperty, ok := body[BodyRevisions] @@ -751,7 +751,7 @@ func ParseRevisions(body Body) []string { if !ok { return nil } - if genOfRevID(revid) < 1 { + if genOfRevID(ctx, revid) < 1 { return nil } oneRev := make([]string, 0, 1) @@ -787,16 +787,16 @@ func splitRevisionList(revisions Revisions) (int, []string) { // Standard CouchDB encoding of a revision list: digests without numeric generation prefixes go in // the "ids" property, and the first (largest) generation number in the "start" property. // The docID parameter is informational only - and used when logging edge cases. -func encodeRevisions(docID string, revs []string) Revisions { +func encodeRevisions(ctx context.Context, docID string, revs []string) Revisions { ids := make([]string, len(revs)) var start int for i, revid := range revs { - gen, id := ParseRevID(revid) + gen, id := ParseRevID(ctx, revid) ids[i] = id if i == 0 { start = gen } else if gen != start-i { - base.DebugfCtx(context.TODO(), base.KeyCRUD, "Found gap in revision list for doc %q. Expecting gen %v but got %v in %v", base.UD(docID), start-i, gen, revs) + base.DebugfCtx(ctx, base.KeyCRUD, "Found gap in revision list for doc %q. Expecting gen %v but got %v in %v", base.UD(docID), start-i, gen, revs) } } return Revisions{RevisionsStart: start, RevisionsIds: ids} @@ -806,7 +806,7 @@ func encodeRevisions(docID string, revs []string) Revisions { // trim the history to stop at the first ancestor revID. If no ancestors are found, trim to // length maxUnmatchedLen. // TODO: Document/rename what the boolean result return value represents -func trimEncodedRevisionsToAncestor(revs Revisions, ancestors []string, maxUnmatchedLen int) (result bool, trimmedRevs Revisions) { +func trimEncodedRevisionsToAncestor(ctx context.Context, revs Revisions, ancestors []string, maxUnmatchedLen int) (result bool, trimmedRevs Revisions) { trimmedRevs = revs @@ -816,7 +816,7 @@ func trimEncodedRevisionsToAncestor(revs Revisions, ancestors []string, maxUnmat } matchIndex := len(digests) for _, revID := range ancestors { - gen, digest := ParseRevID(revID) + gen, digest := ParseRevID(ctx, revID) if index := start - gen; index >= 0 && index < matchIndex && digest == digests[index] { matchIndex = index maxUnmatchedLen = matchIndex + 1 diff --git a/db/revtree_test.go b/db/revtree_test.go index 9e2a3ca379..9e3f733fe0 100644 --- a/db/revtree_test.go +++ b/db/revtree_test.go @@ -9,6 +9,7 @@ package db import ( + "context" "fmt" "log" "os" @@ -55,7 +56,7 @@ type BranchSpec struct { // \ 3-b -- 4-b ... etc (losing branch) // // NOTE: the 1-a -- 2-a unconflicted branch can be longer, depending on value of unconflictedBranchNumRevs -func getTwoBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, losingBranchNumRevs int, tombstoneLosingBranch bool) RevTree { +func getTwoBranchTestRevtree1(ctx context.Context, unconflictedBranchNumRevs, winningBranchNumRevs, losingBranchNumRevs int, tombstoneLosingBranch bool) RevTree { branchSpecs := []BranchSpec{ { @@ -65,7 +66,7 @@ func getTwoBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, l }, } - return getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) + return getMultiBranchTestRevtree1(ctx, unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) } @@ -76,7 +77,7 @@ func getTwoBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, l // \ 3-d -- 4-d ... etc (losing branch #n) // // NOTE: the 1-a -- 2-a unconflicted branch can be longer, depending on value of unconflictedBranchNumRevs -func getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs int, losingBranches []BranchSpec) RevTree { +func getMultiBranchTestRevtree1(ctx context.Context, unconflictedBranchNumRevs, winningBranchNumRevs int, losingBranches []BranchSpec) RevTree { if unconflictedBranchNumRevs < 1 { panic(fmt.Sprintf("Must have at least 1 unconflictedBranchNumRevs")) @@ -100,10 +101,10 @@ func getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs if err := base.JSONUnmarshal([]byte(testJSON), &revTree); err != nil { panic(fmt.Sprintf("Error: %v", err)) } - if unconflictedBranchNumRevs > 1 { // Add revs to unconflicted branch addRevs( + ctx, revTree, "1-winning", unconflictedBranchNumRevs-1, @@ -121,6 +122,7 @@ func getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs // Add revs to winning branch addRevs( + ctx, revTree, winningBranchStartRev, winningBranchNumRevs, @@ -140,6 +142,7 @@ func getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs // Add revs to losing branch addRevs( + ctx, revTree, losingBranchStartRev, losingBranchSpec.NumRevs, // Subtract 1 since we already added initial @@ -196,7 +199,7 @@ func TestGetMultiBranchTestRevtree(t *testing.T) { LastRevisionIsTombstone: true, }, } - revTree := getMultiBranchTestRevtree1(50, 100, branchSpecs) + revTree := getMultiBranchTestRevtree1(base.TestCtx(t), 50, 100, branchSpecs) leaves := revTree.GetLeaves() sort.Strings(leaves) assert.Equal(t, []string{"110-left", "150-winning", "76-right"}, leaves) @@ -299,11 +302,12 @@ func TestRevTreeAddRevisionWithMissingParent(t *testing.T) { } func TestRevTreeCompareRevIDs(t *testing.T) { - assert.Equal(t, 0, compareRevIDs("1-aaa", "1-aaa")) - assert.Equal(t, -1, compareRevIDs("1-aaa", "5-aaa")) - assert.Equal(t, 1, compareRevIDs("10-aaa", "5-aaa")) - assert.Equal(t, 1, compareRevIDs("1-bbb", "1-aaa")) - assert.Equal(t, 1, compareRevIDs("5-bbb", "1-zzz")) + ctx := base.TestCtx(t) + assert.Equal(t, 0, compareRevIDs(ctx, "1-aaa", "1-aaa")) + assert.Equal(t, -1, compareRevIDs(ctx, "1-aaa", "5-aaa")) + assert.Equal(t, 1, compareRevIDs(ctx, "10-aaa", "5-aaa")) + assert.Equal(t, 1, compareRevIDs(ctx, "1-bbb", "1-aaa")) + assert.Equal(t, 1, compareRevIDs(ctx, "5-bbb", "1-zzz")) } func TestRevTreeIsLeaf(t *testing.T) { @@ -315,20 +319,21 @@ func TestRevTreeIsLeaf(t *testing.T) { } func TestRevTreeWinningRev(t *testing.T) { + ctx := base.TestCtx(t) tempmap := branchymap.copy() - winner, branched, conflict := tempmap.winningRevision() + winner, branched, conflict := tempmap.winningRevision(ctx) assert.Equal(t, "3-three", winner) assert.True(t, branched) assert.True(t, conflict) err := tempmap.addRevision("testdoc", RevInfo{ID: "4-four", Parent: "3-three"}) require.NoError(t, err) - winner, branched, conflict = tempmap.winningRevision() + winner, branched, conflict = tempmap.winningRevision(ctx) assert.Equal(t, "4-four", winner) assert.True(t, branched) assert.True(t, conflict) err = tempmap.addRevision("testdoc", RevInfo{ID: "5-five", Parent: "4-four", Deleted: true}) require.NoError(t, err) - winner, branched, conflict = tempmap.winningRevision() + winner, branched, conflict = tempmap.winningRevision(ctx) assert.Equal(t, "3-drei", winner) assert.True(t, branched) assert.False(t, conflict) @@ -357,19 +362,20 @@ func TestPruneRevisions(t *testing.T) { assert.Equal(t, uint32(2), tempmap["2-two"].depth) assert.Equal(t, uint32(3), tempmap["1-one"].depth) + ctx := base.TestCtx(t) // Prune: - pruned, _ := tempmap.pruneRevisions(1000, "") + pruned, _ := tempmap.pruneRevisions(ctx, 1000, "") assert.Equal(t, 0, pruned) - pruned, _ = tempmap.pruneRevisions(3, "") + pruned, _ = tempmap.pruneRevisions(ctx, 3, "") assert.Equal(t, 0, pruned) - pruned, _ = tempmap.pruneRevisions(2, "") + pruned, _ = tempmap.pruneRevisions(ctx, 2, "") assert.Equal(t, 1, pruned) assert.Equal(t, 4, len(tempmap)) assert.Equal(t, (*RevInfo)(nil), tempmap["1-one"]) assert.Equal(t, "", tempmap["2-two"].Parent) // Make sure leaves are never pruned: - pruned, _ = tempmap.pruneRevisions(1, "") + pruned, _ = tempmap.pruneRevisions(ctx, 1, "") assert.Equal(t, 2, pruned) assert.Equal(t, 2, len(tempmap)) assert.True(t, tempmap["3-three"] != nil) @@ -382,13 +388,13 @@ func TestPruneRevisions(t *testing.T) { func TestPruneRevsSingleBranch(t *testing.T) { numRevs := 100 - - revTree := getMultiBranchTestRevtree1(numRevs, 0, []BranchSpec{}) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, numRevs, 0, []BranchSpec{}) maxDepth := uint32(20) expectedNumPruned := numRevs - int(maxDepth) - numPruned, _ := revTree.pruneRevisions(maxDepth, "") + numPruned, _ := revTree.pruneRevisions(ctx, maxDepth, "") assert.Equal(t, expectedNumPruned, numPruned) } @@ -405,12 +411,12 @@ func TestPruneRevsOneWinningOneNonwinningBranch(t *testing.T) { unconflictedBranchNumRevs := 2 winningBranchNumRevs := 4 - - revTree := getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) maxDepth := uint32(2) - revTree.pruneRevisions(maxDepth, "") + revTree.pruneRevisions(ctx, maxDepth, "") assert.Equal(t, int(maxDepth), revTree.LongestBranch()) @@ -429,17 +435,18 @@ func TestPruneRevsOneWinningOneOldTombstonedBranch(t *testing.T) { unconflictedBranchNumRevs := 1 winningBranchNumRevs := 5 - revTree := getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) maxDepth := uint32(2) - revTree.pruneRevisions(maxDepth, "") + revTree.pruneRevisions(ctx, maxDepth, "") assert.True(t, revTree.LongestBranch() == int(maxDepth)) // we shouldn't have any tombstoned branches, since the tombstoned branch was so old // it should have been pruned away - assert.Equal(t, 0, revTree.FindLongestTombstonedBranch()) + assert.Equal(t, 0, revTree.FindLongestTombstonedBranch(ctx)) } @@ -461,11 +468,12 @@ func TestPruneRevsOneWinningOneOldAndOneRecentTombstonedBranch(t *testing.T) { unconflictedBranchNumRevs := 1 winningBranchNumRevs := 5 - revTree := getMultiBranchTestRevtree1(unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, unconflictedBranchNumRevs, winningBranchNumRevs, branchSpecs) maxDepth := uint32(2) - revTree.pruneRevisions(maxDepth, "") + revTree.pruneRevisions(ctx, maxDepth, "") assert.True(t, revTree.LongestBranch() == int(maxDepth)) @@ -485,7 +493,7 @@ func TestPruneRevsOneWinningOneOldAndOneRecentTombstonedBranch(t *testing.T) { // + // 1 extra rev in branchspec since LastRevisionIsTombstone (that variable name is misleading) expectedGenLongestTSd := 6 - assert.Equal(t, expectedGenLongestTSd, revTree.FindLongestTombstonedBranch()) + assert.Equal(t, expectedGenLongestTSd, revTree.FindLongestTombstonedBranch(ctx)) } @@ -504,9 +512,9 @@ func TestGenerationShortestNonTombstonedBranch(t *testing.T) { }, } - revTree := getMultiBranchTestRevtree1(3, 7, branchSpecs) + revTree := getMultiBranchTestRevtree1(base.TestCtx(t), 3, 7, branchSpecs) - generationShortestNonTombstonedBranch, _ := revTree.FindShortestNonTombstonedBranch() + generationShortestNonTombstonedBranch, _ := revTree.FindShortestNonTombstonedBranch(base.TestCtx(t)) // The "non-winning unresolved" branch has 7 revisions due to: // 3 unconflictedBranchNumRevs @@ -539,9 +547,9 @@ func TestGenerationLongestTombstonedBranch(t *testing.T) { LastRevisionIsTombstone: true, }, } - - revTree := getMultiBranchTestRevtree1(3, 7, branchSpecs) - generationLongestTombstonedBranch := revTree.FindLongestTombstonedBranch() + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, 3, 7, branchSpecs) + generationLongestTombstonedBranch := revTree.FindLongestTombstonedBranch(ctx) // The generation of the longest deleted branch is: // 3 unconflictedBranchNumRevs @@ -571,10 +579,11 @@ func TestPruneRevisionsPostIssue2651ThreeBranches(t *testing.T) { LastRevisionIsTombstone: true, }, } - revTree := getMultiBranchTestRevtree1(50, 100, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, 50, 100, branchSpecs) maxDepth := uint32(50) - numPruned, _ := revTree.pruneRevisions(maxDepth, "") + numPruned, _ := revTree.pruneRevisions(ctx, maxDepth, "") t.Logf("numPruned: %v", numPruned) t.Logf("LongestBranch: %v", revTree.LongestBranch()) @@ -594,7 +603,8 @@ func TestPruneRevsSingleTombstonedBranch(t *testing.T) { }, } - revTree := getMultiBranchTestRevtree1(1, 0, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, 1, 0, branchSpecs) log.Printf("RevTreeAfter before: %v", revTree.RenderGraphvizDot()) @@ -603,7 +613,7 @@ func TestPruneRevsSingleTombstonedBranch(t *testing.T) { expectedNumPruned += 1 // To account for the tombstone revision in the branchspec, which is spearate from NumRevs - numPruned, _ := revTree.pruneRevisions(maxDepth, "") + numPruned, _ := revTree.pruneRevisions(ctx, maxDepth, "") log.Printf("RevTreeAfter pruning: %v", revTree.RenderGraphvizDot()) @@ -625,7 +635,7 @@ func TestLongestBranch1(t *testing.T) { LastRevisionIsTombstone: true, }, } - revTree := getMultiBranchTestRevtree1(50, 100, branchSpecs) + revTree := getMultiBranchTestRevtree1(base.TestCtx(t), 50, 100, branchSpecs) assert.True(t, revTree.LongestBranch() == 150) @@ -652,7 +662,8 @@ func TestPruneDisconnectedRevTreeWithLongWinningBranch(t *testing.T) { LastRevisionIsTombstone: false, }, } - revTree := getMultiBranchTestRevtree1(1, 15, branchSpecs) + ctx := base.TestCtx(t) + revTree := getMultiBranchTestRevtree1(ctx, 1, 15, branchSpecs) if dumpRevTreeDotFiles { err := os.WriteFile("/tmp/TestPruneDisconnectedRevTreeWithLongWinningBranch_initial.dot", []byte(revTree.RenderGraphvizDot()), 0666) @@ -661,7 +672,7 @@ func TestPruneDisconnectedRevTreeWithLongWinningBranch(t *testing.T) { maxDepth := uint32(7) - revTree.pruneRevisions(maxDepth, "") + revTree.pruneRevisions(ctx, maxDepth, "") if dumpRevTreeDotFiles { err := os.WriteFile("/tmp/TestPruneDisconnectedRevTreeWithLongWinningBranch_pruned1.dot", []byte(revTree.RenderGraphvizDot()), 0666) @@ -672,6 +683,7 @@ func TestPruneDisconnectedRevTreeWithLongWinningBranch(t *testing.T) { // Add revs to winning branch addRevs( + ctx, revTree, winningBranchStartRev, 10, @@ -683,7 +695,7 @@ func TestPruneDisconnectedRevTreeWithLongWinningBranch(t *testing.T) { require.NoError(t, err) } - revTree.pruneRevisions(maxDepth, "") + revTree.pruneRevisions(base.TestCtx(t), maxDepth, "") if dumpRevTreeDotFiles { err := os.WriteFile("/tmp/TestPruneDisconnectedRevTreeWithLongWinningBranch_pruned_final.dot", []byte(revTree.RenderGraphvizDot()), 0666) @@ -721,7 +733,7 @@ func TestParseRevisions(t *testing.T) { var body Body unmarshalErr := body.Unmarshal([]byte(c.json)) assert.NoError(t, unmarshalErr, "base JSON in test case") - ids := ParseRevisions(body) + ids := ParseRevisions(base.TestCtx(t), body) assert.Equal(t, c.ids, ids) } } @@ -755,59 +767,60 @@ func BenchmarkEncodeRevisions(b *testing.B) { }, } + ctx := base.TestCtx(b) for _, test := range tests { docID := b.Name() + "-" + test.name b.Run(test.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = encodeRevisions(docID, test.input) + _ = encodeRevisions(ctx, docID, test.input) } }) } } func TestEncodeRevisions(t *testing.T) { - encoded := encodeRevisions(t.Name(), []string{"5-huey", "4-dewey", "3-louie"}) + encoded := encodeRevisions(base.TestCtx(t), t.Name(), []string{"5-huey", "4-dewey", "3-louie"}) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey", "louie"}}, encoded) } func TestEncodeRevisionsGap(t *testing.T) { - encoded := encodeRevisions(t.Name(), []string{"5-huey", "3-louie"}) + encoded := encodeRevisions(base.TestCtx(t), t.Name(), []string{"5-huey", "3-louie"}) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "louie"}}, encoded) } func TestEncodeRevisionsZero(t *testing.T) { - encoded := encodeRevisions(t.Name(), []string{"1-foo", "0-bar"}) + encoded := encodeRevisions(base.TestCtx(t), t.Name(), []string{"1-foo", "0-bar"}) assert.Equal(t, Revisions{RevisionsStart: 1, RevisionsIds: []string{"foo", ""}}, encoded) } func TestTrimEncodedRevisionsToAncestor(t *testing.T) { + ctx := base.TestCtx(t) + encoded := encodeRevisions(ctx, t.Name(), []string{"5-huey", "4-dewey", "3-louie", "2-screwy"}) - encoded := encodeRevisions(t.Name(), []string{"5-huey", "4-dewey", "3-louie", "2-screwy"}) - - result, trimmedRevs := trimEncodedRevisionsToAncestor(encoded, []string{"3-walter", "17-gretchen", "1-fooey"}, 1000) + result, trimmedRevs := trimEncodedRevisionsToAncestor(ctx, encoded, []string{"3-walter", "17-gretchen", "1-fooey"}, 1000) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey", "louie", "screwy"}}, trimmedRevs) - result, trimmedRevs = trimEncodedRevisionsToAncestor(trimmedRevs, []string{"3-walter", "3-louie", "1-fooey"}, 2) + result, trimmedRevs = trimEncodedRevisionsToAncestor(ctx, trimmedRevs, []string{"3-walter", "3-louie", "1-fooey"}, 2) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey", "louie"}}, trimmedRevs) - result, trimmedRevs = trimEncodedRevisionsToAncestor(trimmedRevs, []string{"3-walter", "3-louie", "1-fooey"}, 3) + result, trimmedRevs = trimEncodedRevisionsToAncestor(ctx, trimmedRevs, []string{"3-walter", "3-louie", "1-fooey"}, 3) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey", "louie"}}, trimmedRevs) - result, trimmedRevs = trimEncodedRevisionsToAncestor(trimmedRevs, []string{"3-walter", "3-louie", "5-huey"}, 3) + result, trimmedRevs = trimEncodedRevisionsToAncestor(ctx, trimmedRevs, []string{"3-walter", "3-louie", "5-huey"}, 3) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey"}}, trimmedRevs) // Check maxLength with no ancestors: - encoded = encodeRevisions(t.Name(), []string{"5-huey", "4-dewey", "3-louie", "2-screwy"}) + encoded = encodeRevisions(base.TestCtx(t), t.Name(), []string{"5-huey", "4-dewey", "3-louie", "2-screwy"}) - result, trimmedRevs = trimEncodedRevisionsToAncestor(encoded, nil, 6) + result, trimmedRevs = trimEncodedRevisionsToAncestor(ctx, encoded, nil, 6) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey", "louie", "screwy"}}, trimmedRevs) - result, trimmedRevs = trimEncodedRevisionsToAncestor(trimmedRevs, nil, 2) + result, trimmedRevs = trimEncodedRevisionsToAncestor(ctx, trimmedRevs, nil, 2) assert.True(t, result) assert.Equal(t, Revisions{RevisionsStart: 5, RevisionsIds: []string{"huey", "dewey"}}, trimmedRevs) } @@ -842,6 +855,7 @@ func TestRepairRevsHistoryWithCycles(t *testing.T) { base.SetUpTestLogging(t, base.LevelInfo, base.KeyCRUD) + ctx := base.TestCtx(t) for i, testdocProblematicRevTree := range testdocProblematicRevTrees { docId := "testdocProblematicRevTree" @@ -851,7 +865,7 @@ func TestRepairRevsHistoryWithCycles(t *testing.T) { t.Fatalf("Error unmarshalling doc %d: %v", i, err) } - if err := rawDoc.History.RepairCycles(); err != nil { + if err := rawDoc.History.RepairCycles(ctx); err != nil { t.Fatalf("Unable to repair doc. Err: %v", err) } @@ -923,14 +937,15 @@ func TestRevisionPruningLoop(t *testing.T) { // Keep adding to the main branch without pruning. Simulates old pruning algorithm, // which maintained rev history due to tombstone branch + ctx := base.TestCtx(t) for generation := 6; generation <= 15; generation++ { revID := fmt.Sprintf("%d-foo", generation) parentRevID := fmt.Sprintf("%d-foo", generation-1) - _, err := addPruneAndGet(revTree, revID, parentRevID, revBody, revsLimit, nonTombstone) + _, err := addPruneAndGet(ctx, revTree, revID, parentRevID, revBody, revsLimit, nonTombstone) assert.NoError(t, err, fmt.Sprintf("Error adding revision %s to tree", revID)) keepAliveRevID := fmt.Sprintf("%d-keep", generation) - _, err = addPruneAndGet(revTree, keepAliveRevID, parentRevID, revBody, revsLimit, tombstone) + _, err = addPruneAndGet(ctx, revTree, keepAliveRevID, parentRevID, revBody, revsLimit, tombstone) assert.NoError(t, err, fmt.Sprintf("Error adding revision %s to tree", revID)) // The act of marshalling the rev tree and then unmarshalling back into a revtree data structure @@ -986,7 +1001,7 @@ func TestPruneRevisionsWithDisconnected(t *testing.T) { "73-abc": {ID: "73-abc", Parent: "72-abc", Deleted: true}, } - prunedCount, _ := revTree.pruneRevisions(4, "") + prunedCount, _ := revTree.pruneRevisions(base.TestCtx(t), 4, "") assert.Equal(t, 10, prunedCount) remainingKeys := make([]string, 0, len(revTree)) @@ -998,14 +1013,14 @@ func TestPruneRevisionsWithDisconnected(t *testing.T) { assert.Equal(t, []string{"101-abc", "102-abc", "103-abc", "103-def", "104-abc", "105-abc", "106-abc", "106-def", "107-abc"}, remainingKeys) } -func addPruneAndGet(revTree RevTree, revID string, parentRevID string, revBody []byte, revsLimit uint32, tombstone bool) (numPruned int, err error) { +func addPruneAndGet(ctx context.Context, revTree RevTree, revID string, parentRevID string, revBody []byte, revsLimit uint32, tombstone bool) (numPruned int, err error) { _ = revTree.addRevision("doc", RevInfo{ ID: revID, Parent: parentRevID, Body: revBody, Deleted: tombstone, }) - numPruned, _ = revTree.pruneRevisions(revsLimit, revID) + numPruned, _ = revTree.pruneRevisions(ctx, revsLimit, revID) // Get history for new rev (checks for loops) history, err := revTree.getHistory(revID) @@ -1057,15 +1072,16 @@ func BenchmarkRevTreePruning(b *testing.B) { }, } + ctx := base.TestCtx(b) b.ResetTimer() for i := 0; i < b.N; i++ { b.StopTimer() - revTree := getMultiBranchTestRevtree1(50, 100, branchSpecs) + revTree := getMultiBranchTestRevtree1(ctx, 50, 100, branchSpecs) b.StartTimer() - revTree.pruneRevisions(50, "") + revTree.pruneRevisions(ctx, 50, "") } } @@ -1095,7 +1111,7 @@ func BenchmarkRevtreeUnmarshal(b *testing.B) { }) } -func addRevs(revTree RevTree, startingParentRevId string, numRevs int, revDigest string) { +func addRevs(ctx context.Context, revTree RevTree, startingParentRevId string, numRevs int, revDigest string) { docSizeBytes := 1024 * 5 body := createBodyContentAsMapWithSize(docSizeBytes) @@ -1106,7 +1122,7 @@ func addRevs(revTree RevTree, startingParentRevId string, numRevs int, revDigest channels := base.SetOf("ABC", "CBS") - generation, _ := ParseRevID(startingParentRevId) + generation, _ := ParseRevID(ctx, startingParentRevId) for i := 0; i < numRevs; i++ { diff --git a/db/sequence_allocator.go b/db/sequence_allocator.go index f2e6fbf40c..6fd06d6e0c 100644 --- a/db/sequence_allocator.go +++ b/db/sequence_allocator.go @@ -75,7 +75,7 @@ func newSequenceAllocator(ctx context.Context, datastore base.DataStore, dbStats defer base.FatalPanicHandler() s.releaseSequenceMonitor(ctx) }() - _, err := s.lastSequence() // just reads latest sequence from bucket + _, err := s.lastSequence(ctx) // just reads latest sequence from bucket return s, err } @@ -141,7 +141,7 @@ func (s *sequenceAllocator) releaseUnusedSequences(ctx context.Context) { // Retrieves the last allocated sequence. If there hasn't been an allocation yet by this node, // retrieves the value of the _sync:seq counter from the bucket by doing an incr(0) -func (s *sequenceAllocator) lastSequence() (uint64, error) { +func (s *sequenceAllocator) lastSequence(ctx context.Context) (uint64, error) { s.mutex.Lock() lastSeq := s.last s.mutex.Unlock() @@ -152,7 +152,7 @@ func (s *sequenceAllocator) lastSequence() (uint64, error) { s.dbStats.SequenceGetCount.Add(1) last, err := s.getSequence() if err != nil { - base.WarnfCtx(context.TODO(), "Error from Get in getSequence(): %v", err) + base.WarnfCtx(ctx, "Error from Get in getSequence(): %v", err) } return last, err } @@ -161,11 +161,11 @@ func (s *sequenceAllocator) lastSequence() (uint64, error) { // If previously reserved sequences are available (s.last < s.max), returns one // and increments s.last. // If no previously reserved sequences are available, reserves new batch. -func (s *sequenceAllocator) nextSequence() (sequence uint64, err error) { +func (s *sequenceAllocator) nextSequence(ctx context.Context) (sequence uint64, err error) { s.mutex.Lock() sequencesReserved := false if s.last >= s.max { - if err := s._reserveSequenceRange(); err != nil { + if err := s._reserveSequenceRange(ctx); err != nil { s.mutex.Unlock() return 0, err } @@ -186,7 +186,7 @@ func (s *sequenceAllocator) nextSequence() (sequence uint64, err error) { } // Reserve a new sequence range. Called by nextSequence when the previously allocated sequences have all been used. -func (s *sequenceAllocator) _reserveSequenceRange() error { +func (s *sequenceAllocator) _reserveSequenceRange(ctx context.Context) error { // If the time elapsed since the last reserveSequenceRange invocation reserve is shorter than our target frequency, // this indicates we're making an incr call more frequently than we want to. Triggers an increase in batch size to @@ -196,12 +196,12 @@ func (s *sequenceAllocator) _reserveSequenceRange() error { if s.sequenceBatchSize > maxBatchSize { s.sequenceBatchSize = maxBatchSize } - base.DebugfCtx(context.TODO(), base.KeyCRUD, "Increased sequence batch to %d", s.sequenceBatchSize) + base.DebugfCtx(ctx, base.KeyCRUD, "Increased sequence batch to %d", s.sequenceBatchSize) } max, err := s.incrementSequence(s.sequenceBatchSize) if err != nil { - base.WarnfCtx(context.TODO(), "Error from incrementSequence in _reserveSequences(%d): %v", s.sequenceBatchSize, err) + base.WarnfCtx(ctx, "Error from incrementSequence in _reserveSequences(%d): %v", s.sequenceBatchSize, err) return err } @@ -263,13 +263,13 @@ func (s *sequenceAllocator) releaseSequenceRange(ctx context.Context, fromSequen // waitForReleasedSequences blocks for 'releaseSequenceWait' past the provided startTime. // Used to guarantee assignment of allocated sequences on other nodes. -func (s *sequenceAllocator) waitForReleasedSequences(startTime time.Time) (waitedFor time.Duration) { +func (s *sequenceAllocator) waitForReleasedSequences(ctx context.Context, startTime time.Time) (waitedFor time.Duration) { requiredWait := s.releaseSequenceWait - time.Since(startTime) if requiredWait < 0 { return 0 } - base.InfofCtx(context.TODO(), base.KeyCache, "Waiting %v for sequence allocation...", requiredWait) + base.InfofCtx(ctx, base.KeyCache, "Waiting %v for sequence allocation...", requiredWait) time.Sleep(requiredWait) return requiredWait } diff --git a/db/sequence_allocator_test.go b/db/sequence_allocator_test.go index 574248e959..f791e6b6d0 100644 --- a/db/sequence_allocator_test.go +++ b/db/sequence_allocator_test.go @@ -46,42 +46,44 @@ func TestSequenceAllocator(t *testing.T) { defer func() { MaxSequenceIncrFrequency = oldFrequency }() MaxSequenceIncrFrequency = 60 * time.Second - initSequence, err := a.lastSequence() + ctx := base.TestCtx(t) + + initSequence, err := a.lastSequence(ctx) assert.Equal(t, uint64(0), initSequence) assert.NoError(t, err, "error retrieving last sequence") // Initial allocation should use batch size of 1 - nextSequence, err := a.nextSequence() + nextSequence, err := a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(1), nextSequence) assertNewAllocatorStats(t, testStats, 1, 1, 1, 0) // Subsequent allocation should increase batch size to 2, allocate 1 - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(2), nextSequence) assertNewAllocatorStats(t, testStats, 2, 3, 2, 0) // Subsequent allocation shouldn't trigger allocation - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(3), nextSequence) assertNewAllocatorStats(t, testStats, 2, 3, 3, 0) // Subsequent allocation should increase batch to 4, allocate 1 - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(4), nextSequence) assert.Equal(t, 4, int(a.sequenceBatchSize)) assertNewAllocatorStats(t, testStats, 3, 7, 4, 0) // Release unused sequences. Should reduce batch size to 1 (based on 3 unused) - a.releaseUnusedSequences(base.TestCtx(t)) + a.releaseUnusedSequences(ctx) assertNewAllocatorStats(t, testStats, 3, 7, 4, 3) assert.Equal(t, 1, int(a.sequenceBatchSize)) // Subsequent allocation should increase batch to 2, allocate 1 - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(8), nextSequence) assertNewAllocatorStats(t, testStats, 4, 9, 5, 3) @@ -103,20 +105,20 @@ func TestReleaseSequencesOnStop(t *testing.T) { oldFrequency := MaxSequenceIncrFrequency defer func() { MaxSequenceIncrFrequency = oldFrequency }() MaxSequenceIncrFrequency = 1000 * time.Millisecond - - a, err := newSequenceAllocator(base.TestCtx(t), bucket.GetSingleDataStore(), testStats, base.DefaultMetadataKeys) + ctx := base.TestCtx(t) + a, err := newSequenceAllocator(ctx, bucket.GetSingleDataStore(), testStats, base.DefaultMetadataKeys) // Reduce sequence wait for Stop testing a.releaseSequenceWait = 10 * time.Millisecond assert.NoError(t, err, "error creating allocator") // Initial allocation should use batch size of 1 - nextSequence, err := a.nextSequence() + nextSequence, err := a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(1), nextSequence) assertNewAllocatorStats(t, testStats, 1, 1, 1, 0) // Subsequent allocation should increase batch size to 2, allocate 1 - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(2), nextSequence) assertNewAllocatorStats(t, testStats, 2, 3, 2, 0) @@ -149,6 +151,7 @@ func TestSequenceAllocatorDeadlock(t *testing.T) { var err error var wg sync.WaitGroup + ctx := base.TestCtx(t) callbackCount := 0 incrCallback := func() { callbackCount++ @@ -160,7 +163,7 @@ func TestSequenceAllocatorDeadlock(t *testing.T) { for i := 0; i < 500; i++ { wg.Add(1) go func(a *sequenceAllocator) { - _, err := a.nextSequence() + _, err := a.nextSequence(ctx) assert.NoError(t, err) wg.Done() }(a) @@ -186,11 +189,11 @@ func TestSequenceAllocatorDeadlock(t *testing.T) { a.releaseSequenceWait = 10 * time.Millisecond assert.NoError(t, err, "error creating allocator") - nextSequence, err := a.nextSequence() + nextSequence, err := a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(1), nextSequence) - nextSequence, err = a.nextSequence() + nextSequence, err = a.nextSequence(ctx) assert.NoError(t, err) assert.Equal(t, uint64(2), nextSequence) @@ -208,13 +211,13 @@ func TestReleaseSequenceWait(t *testing.T) { dbstats, err := sgw.NewDBStats("", false, false, false, nil, nil) require.NoError(t, err) testStats := dbstats.Database() - - a, err := newSequenceAllocator(base.TestCtx(t), bucket.GetSingleDataStore(), testStats, base.DefaultMetadataKeys) + ctx := base.TestCtx(t) + a, err := newSequenceAllocator(ctx, bucket.GetSingleDataStore(), testStats, base.DefaultMetadataKeys) require.NoError(t, err) defer a.Stop(base.TestCtx(t)) startTime := time.Now().Add(-1 * time.Second) - amountWaited := a.waitForReleasedSequences(startTime) + amountWaited := a.waitForReleasedSequences(ctx, startTime) // Time will be a little less than a.releaseSequenceWait - 1*time.Second - validate // there's a non-zero wait that's less than releaseSequenceWait assert.True(t, amountWaited > 0) @@ -222,7 +225,7 @@ func TestReleaseSequenceWait(t *testing.T) { // Validate no wait for a time in the past longer than releaseSequenceWait noWaitTime := time.Now().Add(-5 * time.Second) - amountWaited = a.waitForReleasedSequences(noWaitTime) + amountWaited = a.waitForReleasedSequences(ctx, noWaitTime) assert.Equal(t, time.Duration(0), amountWaited) } diff --git a/db/sequence_id.go b/db/sequence_id.go index 2fa58fae2a..3afeabea97 100644 --- a/db/sequence_id.go +++ b/db/sequence_id.go @@ -68,14 +68,14 @@ func (s SequenceID) intSeqToString() string { } } -func seqStr(seq interface{}) string { +func seqStr(ctx context.Context, seq interface{}) string { switch seq := seq.(type) { case string: return seq case json.Number: return seq.String() } - base.WarnfCtx(context.Background(), "unknown seq type: %T", seq) + base.WarnfCtx(ctx, "unknown seq type: %T", seq) return "" } diff --git a/db/sg_replicate_cfg.go b/db/sg_replicate_cfg.go index 44062626ec..f5932cd2f1 100644 --- a/db/sg_replicate_cfg.go +++ b/db/sg_replicate_cfg.go @@ -257,7 +257,7 @@ func (rc *ReplicationConfig) validateFilteredChannels() error { // Upsert updates ReplicationConfig with any non-empty properties specified in the incoming replication config. // Note that if the intention is to reset the value to default, empty values must be specified. -func (rc *ReplicationConfig) Upsert(c *ReplicationUpsertConfig) { +func (rc *ReplicationConfig) Upsert(ctx context.Context, c *ReplicationUpsertConfig) { if c.Remote != nil { rc.Remote = *c.Remote @@ -351,7 +351,7 @@ func (rc *ReplicationConfig) Upsert(c *ReplicationUpsertConfig) { rc.QueryParams = newParamMap default: // unsupported query params type, don't upsert - base.WarnfCtx(context.Background(), "Unexpected QueryParams type found during upsert, will be ignored (%T): %v", c.QueryParams, c.QueryParams) + base.WarnfCtx(ctx, "Unexpected QueryParams type found during upsert, will be ignored (%T): %v", c.QueryParams, c.QueryParams) } } } @@ -371,7 +371,7 @@ func (rc *ReplicationConfig) Equals(compareToCfg *ReplicationConfig) (bool, erro // Redacted returns the ReplicationCfg with password of the remote database redacted from // both replication config and remote URL, i.e., any password will be replaced with xxxxx. -func (rc *ReplicationConfig) Redacted() *ReplicationConfig { +func (rc *ReplicationConfig) Redacted(ctx context.Context) *ReplicationConfig { config := *rc if config.Password != "" { config.Password = base.RedactedStr @@ -379,7 +379,7 @@ func (rc *ReplicationConfig) Redacted() *ReplicationConfig { if config.RemotePassword != "" { config.RemotePassword = base.RedactedStr } - config.Remote = base.RedactBasicAuthURLPassword(config.Remote) + config.Remote = base.RedactBasicAuthURLPassword(ctx, config.Remote) return &config } @@ -411,7 +411,7 @@ func (ar *ActiveReplicator) alignState(ctx context.Context, targetState string) return nil } - currentState, _ := ar.State() + currentState, _ := ar.State(ctx) if targetState == currentState { return nil } @@ -606,10 +606,10 @@ func (m *sgReplicateManager) NewActiveReplicatorConfig(config *ReplicationCfg) ( // Set conflict resolver for pull replications if rc.Direction == ActiveReplicatorTypePull || rc.Direction == ActiveReplicatorTypePushAndPull { if config.ConflictResolutionType == "" { - rc.ConflictResolverFunc, err = NewConflictResolverFunc(ConflictResolverDefault, "", m.dbContext.Options.JavascriptTimeout) + rc.ConflictResolverFunc, err = NewConflictResolverFunc(m.loggingCtx, ConflictResolverDefault, "", m.dbContext.Options.JavascriptTimeout) } else { - rc.ConflictResolverFunc, err = NewConflictResolverFunc(config.ConflictResolutionType, config.ConflictResolutionFn, m.dbContext.Options.JavascriptTimeout) + rc.ConflictResolverFunc, err = NewConflictResolverFunc(m.loggingCtx, config.ConflictResolutionType, config.ConflictResolutionFn, m.dbContext.Options.JavascriptTimeout) rc.ConflictResolverFuncSrc = config.ConflictResolutionFn } if err != nil { @@ -1071,7 +1071,7 @@ func (m *sgReplicateManager) PutReplications(replications map[string]*Replicatio } // PUT _replication/replicationID -func (m *sgReplicateManager) UpsertReplication(replication *ReplicationUpsertConfig) (created bool, err error) { +func (m *sgReplicateManager) UpsertReplication(ctx context.Context, replication *ReplicationUpsertConfig) (created bool, err error) { created = true addReplicationCallback := func(cluster *SGRCluster) (cancel bool, err error) { @@ -1079,7 +1079,7 @@ func (m *sgReplicateManager) UpsertReplication(replication *ReplicationUpsertCon if exists { created = false // If replication already exists ensure its in the stopped state before allowing upsert - state, err := m.GetReplicationStatus(replication.ID, DefaultReplicationStatusOptions()) + state, err := m.GetReplicationStatus(ctx, replication.ID, DefaultReplicationStatusOptions()) if err != nil { return true, err } @@ -1100,7 +1100,7 @@ func (m *sgReplicateManager) UpsertReplication(replication *ReplicationUpsertCon } } - cluster.Replications[replication.ID].Upsert(replication) + cluster.Replications[replication.ID].Upsert(ctx, replication) validateErr := cluster.Replications[replication.ID].ValidateReplication(false) if validateErr != nil { @@ -1392,7 +1392,7 @@ func DefaultReplicationStatusOptions() ReplicationStatusOptions { } } -func (m *sgReplicateManager) GetReplicationStatus(replicationID string, options ReplicationStatusOptions) (*ReplicationStatus, error) { +func (m *sgReplicateManager) GetReplicationStatus(ctx context.Context, replicationID string, options ReplicationStatusOptions) (*ReplicationStatus, error) { // Check if replication is assigned locally m.activeReplicatorsLock.RLock() @@ -1406,7 +1406,7 @@ func (m *sgReplicateManager) GetReplicationStatus(replicationID string, options var status *ReplicationStatus var remoteCfg *ReplicationCfg if isLocal { - status = replication.GetStatus() + status = replication.GetStatus(ctx) } else { // Attempt to retrieve persisted status var loadErr error @@ -1441,7 +1441,7 @@ func (m *sgReplicateManager) GetReplicationStatus(replicationID string, options return nil, err } } - status.Config = remoteCfg.ReplicationConfig.Redacted() + status.Config = remoteCfg.ReplicationConfig.Redacted(ctx) } if !options.IncludeError && status.Status == ReplicationStateError { @@ -1454,7 +1454,7 @@ func (m *sgReplicateManager) GetReplicationStatus(replicationID string, options return status, nil } -func (m *sgReplicateManager) PutReplicationStatus(replicationID, action string) (status *ReplicationStatus, err error) { +func (m *sgReplicateManager) PutReplicationStatus(ctx context.Context, replicationID, action string) (status *ReplicationStatus, err error) { targetState := "" switch action { @@ -1473,7 +1473,7 @@ func (m *sgReplicateManager) PutReplicationStatus(replicationID, action string) return nil, err } - updatedStatus, err := m.GetReplicationStatus(replicationID, DefaultReplicationStatusOptions()) + updatedStatus, err := m.GetReplicationStatus(ctx, replicationID, DefaultReplicationStatusOptions()) if err != nil { // Not found is expected when adhoc replication is stopped, return removed status instead of error // since UpdateReplicationState was successful @@ -1493,7 +1493,7 @@ func (m *sgReplicateManager) PutReplicationStatus(replicationID, action string) return updatedStatus, nil } -func (m *sgReplicateManager) GetReplicationStatusAll(options ReplicationStatusOptions) ([]*ReplicationStatus, error) { +func (m *sgReplicateManager) GetReplicationStatusAll(ctx context.Context, options ReplicationStatusOptions) ([]*ReplicationStatus, error) { statuses := make([]*ReplicationStatus, 0) @@ -1504,7 +1504,7 @@ func (m *sgReplicateManager) GetReplicationStatusAll(options ReplicationStatusOp } for replicationID, _ := range persistedReplications { - status, err := m.GetReplicationStatus(replicationID, options) + status, err := m.GetReplicationStatus(ctx, replicationID, options) if err != nil { base.WarnfCtx(m.loggingCtx, "Unable to retrieve replication status for replication %s", replicationID) } @@ -1555,7 +1555,7 @@ func (l *ReplicationHeartbeatListener) Name() string { } // When we detect other nodes have stopped pushing heartbeats, use manager to remove from cfg -func (l *ReplicationHeartbeatListener) StaleHeartbeatDetected(nodeUUID string) { +func (l *ReplicationHeartbeatListener) StaleHeartbeatDetected(_ context.Context, nodeUUID string) { base.InfofCtx(l.mgr.loggingCtx, base.KeyCluster, "StaleHeartbeatDetected by sg-replicate listener for node: %v", nodeUUID) err := l.mgr.RemoveNode(nodeUUID) diff --git a/db/sg_replicate_cfg_test.go b/db/sg_replicate_cfg_test.go index 982346dcd8..ad6b2db96f 100644 --- a/db/sg_replicate_cfg_test.go +++ b/db/sg_replicate_cfg_test.go @@ -26,10 +26,10 @@ func TestReplicateManagerReplications(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() - testCfg, err := base.NewCfgSG(testBucket.GetSingleDataStore(), "") + ctx := base.TestCtx(t) + testCfg, err := base.NewCfgSG(ctx, testBucket.GetSingleDataStore(), "") require.NoError(t, err) - ctx := base.TestCtx(t) manager, err := NewSGReplicateManager(ctx, &DatabaseContext{Name: "test"}, testCfg) require.NoError(t, err) defer manager.Stop() @@ -91,10 +91,10 @@ func TestReplicateManagerNodes(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() - testCfg, err := base.NewCfgSG(testBucket.GetSingleDataStore(), "") + ctx := base.TestCtx(t) + testCfg, err := base.NewCfgSG(ctx, testBucket.GetSingleDataStore(), "") require.NoError(t, err) - ctx := base.TestCtx(t) manager, err := NewSGReplicateManager(ctx, &DatabaseContext{Name: "test"}, testCfg) require.NoError(t, err) defer manager.Stop() @@ -148,7 +148,7 @@ func TestReplicateManagerConcurrentNodeOperations(t *testing.T) { defer testBucket.Close() ctx := base.TestCtx(t) - testCfg, err := base.NewCfgSG(testBucket.GetSingleDataStore(), "") + testCfg, err := base.NewCfgSG(ctx, testBucket.GetSingleDataStore(), "") require.NoError(t, err) manager, err := NewSGReplicateManager(ctx, &DatabaseContext{Name: "test"}, testCfg) @@ -193,7 +193,7 @@ func TestReplicateManagerConcurrentReplicationOperations(t *testing.T) { defer testBucket.Close() ctx := base.TestCtx(t) - testCfg, err := base.NewCfgSG(testBucket.GetSingleDataStore(), "") + testCfg, err := base.NewCfgSG(ctx, testBucket.GetSingleDataStore(), "") require.NoError(t, err) manager, err := NewSGReplicateManager(ctx, &DatabaseContext{Name: "test"}, testCfg) @@ -535,7 +535,7 @@ func TestUpsertReplicationConfig(t *testing.T) { } for _, testCase := range testCases { t.Run(fmt.Sprintf("%s", testCase.name), func(t *testing.T) { - testCase.existingConfig.Upsert(testCase.updatedConfig) + testCase.existingConfig.Upsert(base.TestCtx(t), testCase.updatedConfig) equal, err := testCase.existingConfig.Equals(testCase.expectedConfig) assert.NoError(t, err) assert.True(t, equal) @@ -643,10 +643,11 @@ func TestIsCfgChanged(t *testing.T) { testBucket := base.GetTestBucket(t) defer testBucket.Close() - testCfg, err := base.NewCfgSG(testBucket.GetSingleDataStore(), "") + ctx := base.TestCtx(t) + testCfg, err := base.NewCfgSG(ctx, testBucket.GetSingleDataStore(), "") require.NoError(t, err) - mgr, err := NewSGReplicateManager(base.TestCtx(t), &DatabaseContext{Name: "test"}, testCfg) + mgr, err := NewSGReplicateManager(ctx, &DatabaseContext{Name: "test"}, testCfg) require.NoError(t, err) defer mgr.Stop() @@ -656,7 +657,7 @@ func TestIsCfgChanged(t *testing.T) { replicatorConfig, err := mgr.NewActiveReplicatorConfig(replicationCfg) require.NoError(t, err) - replicationCfg.Upsert(testCase.updatedConfig) + replicationCfg.Upsert(base.TestCtx(t), testCase.updatedConfig) isChanged, err := mgr.isCfgChanged(replicationCfg, replicatorConfig) assert.NoError(t, err) diff --git a/db/sg_replicate_conflict_resolver.go b/db/sg_replicate_conflict_resolver.go index 8305590eb9..df7da371b5 100644 --- a/db/sg_replicate_conflict_resolver.go +++ b/db/sg_replicate_conflict_resolver.go @@ -62,7 +62,7 @@ type Conflict struct { // based on a merge of the two. // - In the merge case, winner[revid] must be empty. // - If an nil Body is returned, the conflict should be resolved as a deletion/tombstone. -type ConflictResolverFunc func(conflict Conflict) (winner Body, err error) +type ConflictResolverFunc func(ctx context.Context, conflict Conflict) (winner Body, err error) type ConflictResolverStats struct { ConflictResultMergeCount *base.SgwIntStat @@ -107,9 +107,9 @@ func NewConflictResolver(crf ConflictResolverFunc, statsContainer *base.DbReplic // Wrapper for ConflictResolverFunc that evaluates whether conflict resolution resulted in // localWins, remoteWins, or merge -func (c *ConflictResolver) Resolve(conflict Conflict) (winner Body, resolutionType ConflictResolutionType, err error) { +func (c *ConflictResolver) Resolve(ctx context.Context, conflict Conflict) (winner Body, resolutionType ConflictResolutionType, err error) { - winner, err = c.crf(conflict) + winner, err = c.crf(ctx, conflict) if err != nil { return winner, "", err } @@ -132,7 +132,7 @@ func (c *ConflictResolver) Resolve(conflict Conflict) (winner Body, resolutionTy return winner, ConflictResolutionRemote, nil } - base.InfofCtx(context.Background(), base.KeyReplicate, "Conflict resolver returned non-empty revID (%s) not matching local (%s) or remote (%s), treating result as merge.", winningRev, localRev, remoteRev) + base.InfofCtx(ctx, base.KeyReplicate, "Conflict resolver returned non-empty revID (%s) not matching local (%s) or remote (%s), treating result as merge.", winningRev, localRev, remoteRev) c.stats.ConflictResultMergeCount.Add(1) return winner, ConflictResolutionMerge, err } @@ -141,7 +141,7 @@ func (c *ConflictResolver) Resolve(conflict Conflict) (winner Body, resolutionTy // with the exception that a deleted revision is picked as the winner: // the revision whose (deleted, generation, hash) tuple compares the highest. // Returns error to satisfy ConflictResolverFunc signature. -func DefaultConflictResolver(conflict Conflict) (result Body, err error) { +func DefaultConflictResolver(ctx context.Context, conflict Conflict) (result Body, err error) { localDeleted, _ := conflict.LocalDocument[BodyDeleted].(bool) remoteDeleted, _ := conflict.RemoteDocument[BodyDeleted].(bool) if localDeleted && !remoteDeleted { @@ -153,7 +153,7 @@ func DefaultConflictResolver(conflict Conflict) (result Body, err error) { localRevID, _ := conflict.LocalDocument[BodyRev].(string) remoteRevID, _ := conflict.RemoteDocument[BodyRev].(string) - if compareRevIDs(localRevID, remoteRevID) >= 0 { + if compareRevIDs(ctx, localRevID, remoteRevID) >= 0 { return conflict.LocalDocument, nil } else { return conflict.RemoteDocument, nil @@ -161,16 +161,16 @@ func DefaultConflictResolver(conflict Conflict) (result Body, err error) { } // LocalWinsConflictResolver returns the local document as winner -func LocalWinsConflictResolver(conflict Conflict) (winner Body, err error) { +func LocalWinsConflictResolver(_ context.Context, conflict Conflict) (winner Body, err error) { return conflict.LocalDocument, nil } // RemoteWinsConflictResolver returns the local document as-is -func RemoteWinsConflictResolver(conflict Conflict) (winner Body, err error) { +func RemoteWinsConflictResolver(_ context.Context, conflict Conflict) (winner Body, err error) { return conflict.RemoteDocument, nil } -func NewConflictResolverFunc(resolverType ConflictResolverType, customResolverSource string, customResolverTimeout time.Duration) (ConflictResolverFunc, error) { +func NewConflictResolverFunc(ctx context.Context, resolverType ConflictResolverType, customResolverSource string, customResolverTimeout time.Duration) (ConflictResolverFunc, error) { switch resolverType { case ConflictResolverLocalWins: return LocalWinsConflictResolver, nil @@ -179,7 +179,7 @@ func NewConflictResolverFunc(resolverType ConflictResolverType, customResolverSo case ConflictResolverDefault: return DefaultConflictResolver, nil case ConflictResolverCustom: - return NewCustomConflictResolver(customResolverSource, customResolverTimeout) + return NewCustomConflictResolver(ctx, customResolverSource, customResolverTimeout) default: return nil, fmt.Errorf("Unknown Conflict Resolver type: %s", resolverType) } @@ -187,8 +187,8 @@ func NewConflictResolverFunc(resolverType ConflictResolverType, customResolverSo // NewCustomConflictResolver returns a ConflictResolverFunc that executes the // javascript conflict resolver specified by source -func NewCustomConflictResolver(source string, timeout time.Duration) (ConflictResolverFunc, error) { - conflictResolverJSServer := NewConflictResolverJSServer(source, timeout) +func NewCustomConflictResolver(ctx context.Context, source string, timeout time.Duration) (ConflictResolverFunc, error) { + conflictResolverJSServer := NewConflictResolverJSServer(ctx, source, timeout) return conflictResolverJSServer.EvaluateFunction, nil } @@ -197,21 +197,21 @@ type ConflictResolverJSServer struct { *sgbucket.JSServer } -func NewConflictResolverJSServer(fnSource string, timeout time.Duration) *ConflictResolverJSServer { - base.DebugfCtx(context.Background(), base.KeyReplicate, "Creating new ConflictResolverFunction") +func NewConflictResolverJSServer(ctx context.Context, fnSource string, timeout time.Duration) *ConflictResolverJSServer { + base.DebugfCtx(ctx, base.KeyReplicate, "Creating new ConflictResolverFunction") return &ConflictResolverJSServer{ JSServer: sgbucket.NewJSServer(fnSource, timeout, kTaskCacheSize, newConflictResolverRunner), } } // EvaluateFunction executes the conflict resolver with the provided conflict and returns the result. -func (i *ConflictResolverJSServer) EvaluateFunction(conflict Conflict) (Body, error) { +func (i *ConflictResolverJSServer) EvaluateFunction(ctx context.Context, conflict Conflict) (Body, error) { docID, _ := conflict.LocalDocument[BodyId].(string) localRevID, _ := conflict.LocalDocument[BodyRev].(string) remoteRevID, _ := conflict.RemoteDocument[BodyRev].(string) result, err := i.Call(conflict) if err != nil { - base.WarnfCtx(context.Background(), "Unexpected error invoking conflict resolver for document %s, local/remote revisions %s/%s - processing aborted, document will not be replicated. Error: %v", + base.WarnfCtx(ctx, "Unexpected error invoking conflict resolver for document %s, local/remote revisions %s/%s - processing aborted, document will not be replicated. Error: %v", base.UD(docID), base.UD(localRevID), base.UD(remoteRevID), err) return nil, err } @@ -227,23 +227,24 @@ func (i *ConflictResolverJSServer) EvaluateFunction(conflict Conflict) (Body, er case map[string]interface{}: return result, nil case error: - base.WarnfCtx(context.Background(), "conflictResolverRunner: "+result.Error()) + base.WarnfCtx(ctx, "conflictResolverRunner: "+result.Error()) return nil, result default: - base.WarnfCtx(context.Background(), "Custom conflict resolution function returned non-document result %v Type: %T", result, result) + base.WarnfCtx(ctx, "Custom conflict resolution function returned non-document result %v Type: %T", result, result) return nil, errors.New("Custom conflict resolution function returned non-document value.") } } // Compiles a JavaScript event function to a conflictResolverRunner object. func newConflictResolverRunner(funcSource string, timeout time.Duration) (sgbucket.JSServerTask, error) { + ctx := context.TODO() // fix in sg-bucket conflictResolverRunner := &sgbucket.JSRunner{} err := conflictResolverRunner.InitWithLogging(funcSource, timeout, func(s string) { - base.ErrorfCtx(context.Background(), base.KeyJavascript.String()+": ConflictResolver %s", base.UD(s)) + base.ErrorfCtx(ctx, base.KeyJavascript.String()+": ConflictResolver %s", base.UD(s)) }, func(s string) { - base.InfofCtx(context.Background(), base.KeyJavascript, "ConflictResolver %s", base.UD(s)) + base.InfofCtx(ctx, base.KeyJavascript, "ConflictResolver %s", base.UD(s)) }) if err != nil { return nil, err @@ -252,27 +253,27 @@ func newConflictResolverRunner(funcSource string, timeout time.Duration) (sgbuck // Implementation of the 'defaultPolicy(conflict)' callback: conflictResolverRunner.DefineNativeFunction("defaultPolicy", func(call otto.FunctionCall) otto.Value { if len(call.ArgumentList) == 0 { - return ErrorToOttoValue(conflictResolverRunner, errors.New("No conflict parameter specified when calling defaultPolicy()")) + return ErrorToOttoValue(ctx, conflictResolverRunner, errors.New("No conflict parameter specified when calling defaultPolicy()")) } rawConflict, exportErr := call.Argument(0).Export() if exportErr != nil { - return ErrorToOttoValue(conflictResolverRunner, fmt.Errorf("Unable to export conflict parameter for defaultPolicy(): %v Error: %s", call.Argument(0), exportErr)) + return ErrorToOttoValue(ctx, conflictResolverRunner, fmt.Errorf("Unable to export conflict parameter for defaultPolicy(): %v Error: %s", call.Argument(0), exportErr)) } // Called defaultPolicy with null/undefined value - return if rawConflict == nil || call.Argument(0).IsUndefined() { - return ErrorToOttoValue(conflictResolverRunner, errors.New("Null or undefined value passed to defaultPolicy()")) + return ErrorToOttoValue(ctx, conflictResolverRunner, errors.New("Null or undefined value passed to defaultPolicy()")) } conflict, ok := rawConflict.(Conflict) if !ok { - return ErrorToOttoValue(conflictResolverRunner, fmt.Errorf("Invalid value passed to defaultPolicy(). Value was type %T, expected type Conflict", rawConflict)) + return ErrorToOttoValue(ctx, conflictResolverRunner, fmt.Errorf("Invalid value passed to defaultPolicy(). Value was type %T, expected type Conflict", rawConflict)) } - defaultWinner, _ := DefaultConflictResolver(conflict) + defaultWinner, _ := DefaultConflictResolver(ctx, conflict) ottoDefaultWinner, err := conflictResolverRunner.ToValue(defaultWinner) if err != nil { - return ErrorToOttoValue(conflictResolverRunner, fmt.Errorf("Error converting default winner to javascript value. Error:%w", err)) + return ErrorToOttoValue(ctx, conflictResolverRunner, fmt.Errorf("Error converting default winner to javascript value. Error:%w", err)) } return ottoDefaultWinner }) @@ -286,10 +287,10 @@ func newConflictResolverRunner(funcSource string, timeout time.Duration) (sgbuck } // Converts an error to an otto value, to support native functions returning errors. -func ErrorToOttoValue(runner *sgbucket.JSRunner, err error) otto.Value { +func ErrorToOttoValue(ctx context.Context, runner *sgbucket.JSRunner, err error) otto.Value { errorValue, convertErr := runner.ToValue(err) if convertErr != nil { - base.WarnfCtx(context.Background(), "Unable to convert error to otto value: %v", convertErr) + base.WarnfCtx(ctx, "Unable to convert error to otto value: %v", convertErr) } return errorValue } diff --git a/db/sg_replicate_conflict_resolver_test.go b/db/sg_replicate_conflict_resolver_test.go index b3818bc315..38e15f0ee8 100644 --- a/db/sg_replicate_conflict_resolver_test.go +++ b/db/sg_replicate_conflict_resolver_test.go @@ -13,6 +13,7 @@ package db import ( "testing" + "github.com/couchbase/sync_gateway/base" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -63,7 +64,7 @@ func TestDefaultConflictResolver(t *testing.T) { LocalDocument: test.localDocument, RemoteDocument: test.remoteDocument, } - result, err := DefaultConflictResolver(conflict) + result, err := DefaultConflictResolver(base.TestCtx(t), conflict) assert.NoError(t, err) assert.Equal(tt, test.expectedWinner, result) }) @@ -158,9 +159,10 @@ func TestCustomConflictResolver(t *testing.T) { LocalDocument: test.localDocument, RemoteDocument: test.remoteDocument, } - customConflictResolverFunc, err := NewCustomConflictResolver(test.resolverSource, 0) + ctx := base.TestCtx(t) + customConflictResolverFunc, err := NewCustomConflictResolver(ctx, test.resolverSource, 0) require.NoError(tt, err) - result, err := customConflictResolverFunc(conflict) + result, err := customConflictResolverFunc(ctx, conflict) if test.expectError { assert.Error(t, err) return diff --git a/db/users.go b/db/users.go index aa0c94fbc1..92afbac80b 100644 --- a/db/users.go +++ b/db/users.go @@ -33,7 +33,7 @@ func (db *DatabaseContext) DeleteRole(ctx context.Context, name string, purge bo return base.ErrNotFound } - seq, err := db.sequences.nextSequence() + seq, err := db.sequences.nextSequence(ctx) if err != nil { return err } @@ -183,7 +183,7 @@ func (dbc *DatabaseContext) UpdatePrincipal(ctx context.Context, updates *auth.P // Update the persistent sequence number of this principal (only allocate a sequence when needed - issue #673): nextSeq := uint64(0) - nextSeq, err = dbc.sequences.nextSequence() + nextSeq, err = dbc.sequences.nextSequence(ctx) if err != nil { return replaced, err } diff --git a/db/util_testing.go b/db/util_testing.go index 0645188bdb..65af82a8c7 100644 --- a/db/util_testing.go +++ b/db/util_testing.go @@ -39,6 +39,7 @@ func WaitForPrimaryIndexEmpty(ctx context.Context, store base.N1QLStore) error { // Kick off the retry loop err, _ := base.RetryLoop( + ctx, "Wait for index to be empty", retryWorker, base.CreateMaxDoublingSleeperFunc(60, 500, 5000), @@ -486,8 +487,8 @@ func (dbc *DatabaseContext) GetPrincipalForTest(tb testing.TB, name string, isUs } // TestBucketPoolWithIndexes runs a TestMain for packages that require creation of indexes -func TestBucketPoolWithIndexes(m *testing.M, tbpOptions base.TestBucketPoolOptions) { - base.TestBucketPoolMain(m, viewsAndGSIBucketReadier, viewsAndGSIBucketInit, tbpOptions) +func TestBucketPoolWithIndexes(ctx context.Context, m *testing.M, tbpOptions base.TestBucketPoolOptions) { + base.TestBucketPoolMain(ctx, m, viewsAndGSIBucketReadier, viewsAndGSIBucketInit, tbpOptions) } // Parse the plan looking for use of the fetch operation (appears as the key/value pair "#operator":"Fetch") diff --git a/db/validation.go b/db/validation.go index 2b991a6ec4..de64211166 100644 --- a/db/validation.go +++ b/db/validation.go @@ -10,6 +10,7 @@ package db import ( "bytes" + "context" "net/http" "strings" @@ -72,13 +73,13 @@ func validateImportBody(body Body) error { // validateBlipBody validates incoming blip rev bodies // Takes a rawBody to avoid an unnecessary call to doc.BodyBytes() -func validateBlipBody(rawBody []byte, doc *Document) error { +func validateBlipBody(ctx context.Context, rawBody []byte, doc *Document) error { // Prevent disallowed internal properties from being used disallowed := []string{base.SyncPropertyName, BodyId, BodyRev, BodyDeleted, BodyRevisions} for _, prop := range disallowed { // Only unmarshal if raw body contains the disallowed property if bytes.Contains(rawBody, []byte(`"`+prop+`"`)) { - if _, ok := doc.Body()[prop]; ok { + if _, ok := doc.Body(ctx)[prop]; ok { return base.HTTPErrorf(http.StatusBadRequest, "top-level property '"+prop+"' is a reserved internal property") } } diff --git a/rest/access_test.go b/rest/access_test.go index c5cb65ca3d..ff2df2e94e 100644 --- a/rest/access_test.go +++ b/rest/access_test.go @@ -92,7 +92,7 @@ func TestStarAccess(t *testing.T) { rt := NewRestTester(t, &RestTesterConfig{SyncFn: channels.DocChannelsSyncFunction}) defer rt.Close() - a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) a.Collections = rt.GetDatabase().CollectionNames var changes struct { Results []db.ChangeEntry @@ -306,7 +306,7 @@ func TestNumAccessErrors(t *testing.T) { response := rt.SendUserRequest("PUT", "/{{.keyspace}}/doc", `{"prop":true, "channels":["foo"]}`, "user") RequireStatus(t, response, 403) - base.WaitForStat(func() int64 { return rt.GetDatabase().DbStats.SecurityStats.NumAccessErrors.Value() }, 1) + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.SecurityStats.NumAccessErrors.Value() }, 1) } func TestUserHasDocAccessDocNotFound(t *testing.T) { rt := NewRestTester(t, &RestTesterConfig{ @@ -602,7 +602,7 @@ func TestAllDocsAccessControl(t *testing.T) { } // Create some docs: - a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) a.Collections = rt.GetDatabase().CollectionNames guest, err := a.GetUser("") assert.NoError(t, err) diff --git a/rest/admin_api.go b/rest/admin_api.go index 2118bef306..ad5ba7d5e2 100644 --- a/rest/admin_api.go +++ b/rest/admin_api.go @@ -10,6 +10,7 @@ package rest import ( "bytes" + "context" "errors" "fmt" "io" @@ -63,7 +64,7 @@ func (h *handler) handleCreateDB() error { return base.HTTPErrorf(http.StatusBadRequest, err.Error()) } - version, err := GenerateDatabaseConfigVersionID("", config) + version, err := GenerateDatabaseConfigVersionID(h.ctx(), "", config) if err != nil { return err } @@ -87,7 +88,7 @@ func (h *handler) handleCreateDB() error { dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := config.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := config.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } @@ -156,7 +157,7 @@ func (h *handler) handleCreateDB() error { h.server.dbConfigs[dbName].cfgCas = cas } else { // Intentionally pass in an empty BootstrapConfig to avoid inheriting any credentials or server when running with a legacy config (CBG-1764) - if err := config.setup(dbName, BootstrapConfig{}, nil, nil, false); err != nil { + if err := config.setup(h.ctx(), dbName, BootstrapConfig{}, nil, nil, false); err != nil { return err } @@ -174,12 +175,12 @@ func (h *handler) handleCreateDB() error { // getAuthScopeHandleCreateDB is used in the router to supply an auth scope for the admin api auth. Takes the JSON body // from the payload, pulls out bucket and returns this as the auth scope. -func getAuthScopeHandleCreateDB(bodyJSON []byte) (string, error) { +func getAuthScopeHandleCreateDB(ctx context.Context, bodyJSON []byte) (string, error) { var body struct { Bucket string `json:"bucket"` } reader := bytes.NewReader(bodyJSON) - err := DecodeAndSanitiseConfig(reader, &body, false) + err := DecodeAndSanitiseConfig(ctx, reader, &body, false) if err != nil { return "", err } @@ -296,7 +297,7 @@ func (h *handler) handleGetDbConfig() error { } var err error - responseConfig, err = responseConfig.Redacted() + responseConfig, err = responseConfig.Redacted(h.ctx()) if err != nil { return err } @@ -355,7 +356,7 @@ func (h *handler) handleGetConfig() error { return err } - databaseMap[dbName], err = dbConfig.Redacted() + databaseMap[dbName], err = dbConfig.Redacted(h.ctx()) if err != nil { return err } @@ -375,7 +376,7 @@ func (h *handler) handleGetConfig() error { dbConfig.Replications = make(map[string]*db.ReplicationConfig, len(replications)) for replicationName, replicationConfig := range replications { - dbConfig.Replications[replicationName] = replicationConfig.ReplicationConfig.Redacted() + dbConfig.Replications[replicationName] = replicationConfig.ReplicationConfig.Redacted(h.ctx()) } } @@ -440,7 +441,7 @@ func (h *handler) handlePutConfig() error { testMap[key] = true } - base.UpdateLogKeys(testMap, true) + base.UpdateLogKeys(h.ctx(), testMap, true) } } @@ -559,7 +560,7 @@ func (h *handler) handlePutDbConfig() (err error) { } dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, nil, false); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, nil, false); err != nil { return err } if err := h.server.ReloadDatabaseWithConfig(contextNoCancel, *updatedDbConfig, false); err != nil { @@ -592,7 +593,7 @@ func (h *handler) handlePutDbConfig() (err error) { return nil, base.HTTPErrorf(http.StatusBadRequest, err.Error()) } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -608,7 +609,7 @@ func (h *handler) handlePutDbConfig() (err error) { tmpConfig.cfgCas = bucketDbConfig.cfgCas dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := tmpConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := tmpConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return nil, err } @@ -698,7 +699,7 @@ func (h *handler) handleDeleteCollectionConfigSync() error { bucketDbConfig.Sync = nil } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -716,7 +717,7 @@ func (h *handler) handleDeleteCollectionConfigSync() error { dbName := h.db.Name dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } @@ -764,7 +765,7 @@ func (h *handler) handlePutCollectionConfigSync() error { return nil, base.HTTPErrorf(http.StatusBadRequest, err.Error()) } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -780,7 +781,7 @@ func (h *handler) handlePutCollectionConfigSync() error { dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } @@ -859,7 +860,7 @@ func (h *handler) handleDeleteCollectionConfigImportFilter() error { bucketDbConfig.ImportFilter = nil } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -875,7 +876,7 @@ func (h *handler) handleDeleteCollectionConfigImportFilter() error { dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } @@ -924,7 +925,7 @@ func (h *handler) handlePutCollectionConfigImportFilter() error { return nil, base.HTTPErrorf(http.StatusBadRequest, err.Error()) } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -940,7 +941,7 @@ func (h *handler) handlePutCollectionConfigImportFilter() error { dbCreds, _ := h.server.Config.DatabaseCredentials[dbName] bucketCreds, _ := h.server.Config.BucketCredentials[bucket] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } @@ -964,7 +965,7 @@ func (h *handler) handleDeleteDB() error { var bucket string if h.server.persistentConfig { - bucket, _ = h.server.bucketNameFromDbName(dbName) + bucket, _ = h.server.bucketNameFromDbName(h.ctx(), dbName) err := h.server.BootstrapContext.DeleteConfig(h.ctx(), bucket, h.server.Config.Bootstrap.ConfigGroupID, dbName) if err != nil { return base.HTTPErrorf(http.StatusInternalServerError, "couldn't remove database %q from bucket %q: %s", base.MD(dbName), base.MD(bucket), err.Error()) @@ -1010,7 +1011,7 @@ func (h *handler) handleGetRawDoc() error { if doc.IsDeleted() { rawBytes = []byte(db.DeletedDocument) } else { - docRawBodyBytes, err := doc.BodyBytes() + docRawBodyBytes, err := doc.BodyBytes(h.ctx()) if err != nil { return err } @@ -1100,10 +1101,10 @@ func (h *handler) handleGetStatus() error { // Don't bother trying to lookup LastSequence() if offline if runState != db.RunStateString[db.DBOffline] { - lastSeq, _ = database.LastSequence() + lastSeq, _ = database.LastSequence(h.ctx()) } - replicationsStatus, err := database.SGReplicateMgr.GetReplicationStatusAll(db.DefaultReplicationStatusOptions()) + replicationsStatus, err := database.SGReplicateMgr.GetReplicationStatusAll(h.ctx(), db.DefaultReplicationStatusOptions()) if err != nil { return err } @@ -1112,7 +1113,7 @@ func (h *handler) handleGetStatus() error { return err } for _, replication := range cluster.Replications { - replication.ReplicationConfig = *replication.Redacted() + replication.ReplicationConfig = *replication.Redacted(h.ctx()) } status.Databases[database.Name] = DatabaseStatus{ @@ -1182,7 +1183,7 @@ func (h *handler) handleSetLogging() error { return base.HTTPErrorf(http.StatusBadRequest, "Invalid JSON or non-boolean values for log key map") } - base.UpdateLogKeys(keys, h.rq.Method == "PUT") + base.UpdateLogKeys(h.ctx(), keys, h.rq.Method == "PUT") return nil } @@ -1573,7 +1574,7 @@ func (h *handler) handlePurge() error { } if len(docIDs) > 0 { - count := h.collection.RemoveFromChangeCache(docIDs, startTime) + count := h.collection.RemoveFromChangeCache(h.ctx(), docIDs, startTime) base.DebugfCtx(h.ctx(), base.KeyCache, "Purged %d items from caches", count) } @@ -1596,7 +1597,7 @@ func (h *handler) getReplications() error { } else { replication.AssignedNode = replication.AssignedNode + " (non-local)" } - replication.ReplicationConfig = *replication.Redacted() + replication.ReplicationConfig = *replication.Redacted(h.ctx()) } h.writeJSON(replications) @@ -1613,7 +1614,7 @@ func (h *handler) getReplication() error { return err } - h.writeJSON(replication.Redacted()) + h.writeJSON(replication.Redacted(h.ctx())) return nil } @@ -1638,7 +1639,7 @@ func (h *handler) putReplication() error { replicationConfig.ID = replicationID } - created, err := h.db.SGReplicateMgr.UpsertReplication(replicationConfig) + created, err := h.db.SGReplicateMgr.UpsertReplication(h.ctx(), replicationConfig) if err != nil { return err } @@ -1655,7 +1656,7 @@ func (h *handler) deleteReplication() error { } func (h *handler) getReplicationsStatus() error { - replicationsStatus, err := h.db.SGReplicateMgr.GetReplicationStatusAll(h.getReplicationStatusOptions()) + replicationsStatus, err := h.db.SGReplicateMgr.GetReplicationStatusAll(h.ctx(), h.getReplicationStatusOptions()) if err != nil { return err } @@ -1665,7 +1666,7 @@ func (h *handler) getReplicationsStatus() error { func (h *handler) getReplicationStatus() error { replicationID := mux.Vars(h.rq)["replicationID"] - status, err := h.db.SGReplicateMgr.GetReplicationStatus(replicationID, h.getReplicationStatusOptions()) + status, err := h.db.SGReplicateMgr.GetReplicationStatus(h.ctx(), replicationID, h.getReplicationStatusOptions()) if err != nil { return err } @@ -1694,7 +1695,7 @@ func (h *handler) putReplicationStatus() error { return base.HTTPErrorf(http.StatusBadRequest, "Query parameter 'action' must be specified") } - updatedStatus, err := h.db.SGReplicateMgr.PutReplicationStatus(replicationID, action) + updatedStatus, err := h.db.SGReplicateMgr.PutReplicationStatus(h.ctx(), replicationID, action) if err != nil { return err } diff --git a/rest/admin_api_auth_test.go b/rest/admin_api_auth_test.go index 47963895f4..e678968cdb 100644 --- a/rest/admin_api_auth_test.go +++ b/rest/admin_api_auth_test.go @@ -129,7 +129,7 @@ func TestCheckPermissions(t *testing.T) { defer DeleteUser(t, httpClient, eps[0], testCase.CreateUser) } - statusCode, permResults, err := CheckPermissions(httpClient, eps, "", testCase.Username, testCase.Password, testCase.RequestPermissions, testCase.ResponsePermissions) + statusCode, permResults, err := CheckPermissions(base.TestCtx(t), httpClient, eps, "", testCase.Username, testCase.Password, testCase.RequestPermissions, testCase.ResponsePermissions) require.NoError(t, err) assert.Equal(t, testCase.ExpectedStatusCode, statusCode) assert.True(t, reflect.DeepEqual(testCase.ExpectedPermissionResults, permResults)) @@ -160,14 +160,14 @@ func TestCheckPermissionsWithX509(t *testing.T) { require.NoError(t, err) svrctx.GoCBAgent = goCBAgent - noX509HttpClient, err := svrctx.initializeNoX509HttpClient() + noX509HttpClient, err := svrctx.initializeNoX509HttpClient(ctx) require.NoError(t, err) svrctx.NoX509HTTPClient = noX509HttpClient eps, httpClient, err := svrctx.ObtainManagementEndpointsAndHTTPClient() assert.NoError(t, err) - statusCode, _, err := CheckPermissions(httpClient, eps, "", base.TestClusterUsername(), base.TestClusterPassword(), []Permission{Permission{"!admin", false}}, nil) + statusCode, _, err := CheckPermissions(ctx, httpClient, eps, "", base.TestClusterUsername(), base.TestClusterPassword(), []Permission{Permission{"!admin", false}}, nil) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) @@ -282,7 +282,7 @@ func TestCheckRoles(t *testing.T) { defer DeleteUser(t, httpClient, eps[0], testCase.CreateUser) } - statusCode, err := CheckRoles(httpClient, eps, testCase.Username, testCase.Password, testCase.RequestRoles, testCase.BucketName) + statusCode, err := CheckRoles(base.TestCtx(t), httpClient, eps, testCase.Username, testCase.Password, testCase.RequestRoles, testCase.BucketName) require.NoError(t, err) assert.Equal(t, testCase.ExpectedStatusCode, statusCode) }) @@ -421,7 +421,7 @@ func TestAdminAuth(t *testing.T) { defer DeleteUser(t, httpClient, managementEndpoints[0], testCase.CreateUser) } - permResults, statusCode, err := checkAdminAuth(testCase.BucketName, testCase.Username, testCase.Password, testCase.Operation, httpClient, managementEndpoints, true, testCase.CheckPermissions, testCase.ResponsePermissions) + permResults, statusCode, err := checkAdminAuth(base.TestCtx(t), testCase.BucketName, testCase.Username, testCase.Password, testCase.Operation, httpClient, managementEndpoints, true, testCase.CheckPermissions, testCase.ResponsePermissions) assert.NoError(t, err) assert.Equal(t, testCase.ExpectedStatusCode, statusCode) @@ -458,7 +458,7 @@ func TestAdminAuthWithX509(t *testing.T) { require.NoError(t, err) svrctx.GoCBAgent = goCBAgent - noX509HttpClient, err := svrctx.initializeNoX509HttpClient() + noX509HttpClient, err := svrctx.initializeNoX509HttpClient(ctx) require.NoError(t, err) svrctx.NoX509HTTPClient = noX509HttpClient @@ -466,11 +466,11 @@ func TestAdminAuthWithX509(t *testing.T) { require.NoError(t, err) var statusCode int - _, statusCode, err = checkAdminAuth("", base.TestClusterUsername(), base.TestClusterPassword(), "", httpClient, managementEndpoints, true, []Permission{{"!admin", false}}, nil) + _, statusCode, err = checkAdminAuth(ctx, "", base.TestClusterUsername(), base.TestClusterPassword(), "", httpClient, managementEndpoints, true, []Permission{{"!admin", false}}, nil) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) - _, statusCode, err = checkAdminAuth("", "invalidUser", "invalidPassword", "", httpClient, managementEndpoints, true, []Permission{{"!admin", false}}, nil) + _, statusCode, err = checkAdminAuth(ctx, "", "invalidUser", "invalidPassword", "", httpClient, managementEndpoints, true, []Permission{{"!admin", false}}, nil) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, statusCode) require.Contains(t, err.Error(), ErrInvalidLogin.Message) @@ -988,7 +988,7 @@ func TestDisablePermissionCheck(t *testing.T) { } defer DeleteUser(t, httpClient, eps[0], testCase.CreateUser) - _, statusCode, err := checkAdminAuth(rt.Bucket().GetName(), testCase.CreateUser, "password", "", httpClient, eps, testCase.DoPermissionCheck, testCase.RequirePerms, nil) + _, statusCode, err := checkAdminAuth(rt.Context(), rt.Bucket().GetName(), testCase.CreateUser, "password", "", httpClient, eps, testCase.DoPermissionCheck, testCase.RequirePerms, nil) assert.NoError(t, err) assert.Equal(t, testCase.ExpectedStatusCode, statusCode) diff --git a/rest/adminapitest/admin_api_test.go b/rest/adminapitest/admin_api_test.go index d147649ba1..8695cd4a3d 100644 --- a/rest/adminapitest/admin_api_test.go +++ b/rest/adminapitest/admin_api_test.go @@ -129,7 +129,7 @@ func TestNoPanicInvalidUpdate(t *testing.T) { t.Fatalf("Error unmarshalling response: %v", err) } revId := responseDoc["rev"].(string) - revGeneration, revIdHash := db.ParseRevID(revId) + revGeneration, revIdHash := db.ParseRevID(rt.Context(), revId) assert.Equal(t, 1, revGeneration) // Update doc (normal update, no conflicting revisions added) @@ -147,7 +147,7 @@ func TestNoPanicInvalidUpdate(t *testing.T) { t.Fatalf("Error unmarshalling response: %v", err) } revId = responseDoc["rev"].(string) - revGeneration, _ = db.ParseRevID(revId) + revGeneration, _ = db.ParseRevID(rt.Context(), revId) assert.Equal(t, 2, revGeneration) // Create conflict again, should be a no-op and return the same response as previous attempt @@ -157,7 +157,7 @@ func TestNoPanicInvalidUpdate(t *testing.T) { t.Fatalf("Error unmarshalling response: %v", err) } revId = responseDoc["rev"].(string) - revGeneration, _ = db.ParseRevID(revId) + revGeneration, _ = db.ParseRevID(rt.Context(), revId) assert.Equal(t, 2, revGeneration) } @@ -1019,7 +1019,7 @@ func TestResyncForNamedCollection(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and add new scopes and collections to it. tb := base.GetTestBucket(t) @@ -1153,7 +1153,7 @@ func TestResyncUsingDCPStreamForNamedCollection(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and add new scopes and collections to it. tb := base.GetTestBucket(t) @@ -1561,7 +1561,7 @@ func TestCorruptDbConfigHandling(t *testing.T) { // grab the persisted db config from the bucket databaseConfig := rest.DatabaseConfig{} - _, err := rt.ServerContext().BootstrapContext.GetConfig(rt.CustomTestBucket.GetName(), rt.ServerContext().Config.Bootstrap.ConfigGroupID, "db1", &databaseConfig) + _, err := rt.ServerContext().BootstrapContext.GetConfig(rt.Context(), rt.CustomTestBucket.GetName(), rt.ServerContext().Config.Bootstrap.ConfigGroupID, "db1", &databaseConfig) require.NoError(t, err) // update the persisted config to a fake bucket name @@ -1639,7 +1639,7 @@ func TestBadConfigInsertionToBucket(t *testing.T) { dbConfig := rt.NewDbConfig() dbConfig.Name = "db1" - version, err := rest.GenerateDatabaseConfigVersionID("", &dbConfig) + version, err := rest.GenerateDatabaseConfigVersionID(rt.Context(), "", &dbConfig) require.NoError(t, err) metadataID, metadataIDError := rt.ServerContext().BootstrapContext.ComputeMetadataIDForDbConfig(base.TestCtx(t), &dbConfig) @@ -1877,7 +1877,7 @@ func TestMultipleBucketWithBadDbConfigScenario3(t *testing.T) { rest.RequireStatus(t, resp, http.StatusCreated) // persistence logic construction - version, err := rest.GenerateDatabaseConfigVersionID("", &dbConfig) + version, err := rest.GenerateDatabaseConfigVersionID(rt.Context(), "", &dbConfig) require.NoError(rt.TB, err) metadataID, metadataIDError := rt.ServerContext().BootstrapContext.ComputeMetadataIDForDbConfig(base.TestCtx(rt.TB), &dbConfig) @@ -3291,7 +3291,7 @@ func TestPersistentConfigConcurrency(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3354,7 +3354,7 @@ func TestDbConfigCredentials(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3419,7 +3419,7 @@ func TestInvalidDBConfig(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3473,7 +3473,7 @@ func TestCreateDbOnNonExistentBucket(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) resp := rest.BootstrapAdminRequest(t, http.MethodPut, "/db/", `{"bucket": "nonexistentbucket"}`) resp.RequireStatus(http.StatusForbidden) @@ -3505,7 +3505,7 @@ func TestPutDbConfigChangeName(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3549,7 +3549,7 @@ func TestSwitchDbConfigCollectionName(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and add new scopes and collections to it. tb := base.GetTestBucket(t) @@ -3624,7 +3624,7 @@ func TestPutDBConfigOIDC(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3751,7 +3751,7 @@ func TestConfigsIncludeDefaults(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) resp := rest.BootstrapAdminRequest(t, http.MethodPut, "/db/", fmt.Sprintf( @@ -3834,7 +3834,7 @@ func TestLegacyCredentialInheritance(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3913,7 +3913,7 @@ func TestDbOfflineConfigPersistent(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -3988,7 +3988,7 @@ func TestDbConfigPersistentSGVersions(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -4008,7 +4008,7 @@ func TestDbConfigPersistentSGVersions(t *testing.T) { RevsLimit: base.Uint32Ptr(123), // use RevsLimit to detect config changes }, } - dbConfig.Version, err = rest.GenerateDatabaseConfigVersionID("", &dbConfig.DbConfig) + dbConfig.Version, err = rest.GenerateDatabaseConfigVersionID(ctx, "", &dbConfig.DbConfig) require.NoError(t, err) // initialise with db config @@ -4037,7 +4037,7 @@ func TestDbConfigPersistentSGVersions(t *testing.T) { db.SGVersion = version db.DbConfig.RevsLimit = base.Uint32Ptr(revsLimit) - db.Version, err = rest.GenerateDatabaseConfigVersionID(db.Version, &db.DbConfig) + db.Version, err = rest.GenerateDatabaseConfigVersionID(ctx, db.Version, &db.DbConfig) if err != nil { return nil, err } @@ -4078,7 +4078,7 @@ func TestDbConfigPersistentSGVersions(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) assertRevsLimit(sc, 654) @@ -4103,7 +4103,7 @@ func TestDeleteFunctionsWhileDbOffline(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) defer func() { sc.Close(ctx) require.NoError(t, <-serverErr) @@ -4191,7 +4191,7 @@ func TestSetFunctionsWhileDbOffline(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) defer func() { sc.Close(ctx) require.NoError(t, <-serverErr) @@ -4306,7 +4306,7 @@ func TestEmptyStringJavascriptFunctions(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -4405,7 +4405,7 @@ func TestDeleteDatabasePointingAtSameBucketPersistent(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) defer func() { @@ -4454,7 +4454,7 @@ func TestDeleteDatabasePointingAtSameBucketPersistent(t *testing.T) { } func BootstrapWaitForDatabaseState(t *testing.T, dbName string, state uint32) { - err := base.WaitForNoError(func() error { + err := base.WaitForNoError(base.TestCtx(t), func() error { resp := rest.BootstrapAdminRequest(t, http.MethodGet, "/"+dbName+"/", "") if resp.StatusCode != http.StatusOK { return errors.New("expected 200 status") @@ -4699,7 +4699,7 @@ func TestTombstoneCompactionPurgeInterval(t *testing.T) { ctx := rt.Context() cbStore, _ := base.AsCouchbaseBucketStore(rt.Bucket()) - serverPurgeInterval, err := cbStore.MetadataPurgeInterval() + serverPurgeInterval, err := cbStore.MetadataPurgeInterval(ctx) require.NoError(t, err) // Set server purge interval back to what it was for bucket reuse defer setServerPurgeInterval(t, rt, fmt.Sprintf("%.2f", serverPurgeInterval.Hours()/24)) @@ -4755,9 +4755,9 @@ func TestPerDBCredsOverride(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) - couchbaseCluster, err := rest.CreateCouchbaseClusterFromStartupConfig(sc.Config, base.PerUseClusterConnections) + couchbaseCluster, err := rest.CreateCouchbaseClusterFromStartupConfig(ctx, sc.Config, base.PerUseClusterConnections) require.NoError(t, err) sc.BootstrapContext.Connection = couchbaseCluster diff --git a/rest/adminapitest/main_test.go b/rest/adminapitest/main_test.go index eece6a6b24..bb3c527c0b 100644 --- a/rest/adminapitest/main_test.go +++ b/rest/adminapitest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package adminapitest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 8192} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/api.go b/rest/api.go index ac600245c4..80c8987b0f 100644 --- a/rest/api.go +++ b/rest/api.go @@ -421,7 +421,7 @@ func (h *handler) handleGetDB() error { // Don't bother trying to lookup LastSequence() if offline runState := db.RunStateString[atomic.LoadUint32(&h.db.State)] if runState != db.RunStateString[db.DBOffline] { - lastSeq, _ := h.db.LastSequence() + lastSeq, _ := h.db.LastSequence(h.ctx()) defaultCollectionLastSeq = &lastSeq } diff --git a/rest/api_benchmark_test.go b/rest/api_benchmark_test.go index e7491a6c56..f13e4edb2d 100644 --- a/rest/api_benchmark_test.go +++ b/rest/api_benchmark_test.go @@ -208,7 +208,7 @@ func BenchmarkReadOps_Changes(b *testing.B) { var body db.Body require.NoError(b, base.JSONUnmarshal(response.Body.Bytes(), &body)) revid := body["rev"].(string) - _, rev1_digest := db.ParseRevID(revid) + _, rev1_digest := db.ParseRevID(rt.Context(), revid) response = rt.SendAdminRequest("PUT", fmt.Sprintf("/{{.keyspace}}/doc1k?rev=%s", revid), doc1k_putDoc) if response.Code != 201 { log.Printf("Unexpected add rev response: %d %s", response.Code, response.Body.Bytes()) diff --git a/rest/api_collections_test.go b/rest/api_collections_test.go index 001c8ac6e2..d5bf2f1692 100644 --- a/rest/api_collections_test.go +++ b/rest/api_collections_test.go @@ -720,7 +720,7 @@ func TestCollectionsChangeConfigScope(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Create a DB configured with one scope res := BootstrapAdminRequest(t, http.MethodPut, "/db/", string(mustMarshalJSON(t, map[string]any{ @@ -886,7 +886,7 @@ func TestCollectionStats(t *testing.T) { ok, err := dbc.GetCollectionDatastore().AddRaw("importeddoc", 0, []byte(`{"imported":true}`)) require.NoError(t, err) assert.True(t, ok) - base.WaitForStat(collection2Stats.ImportCount.Value, 1) + base.RequireWaitForStat(t, collection2Stats.ImportCount.Value, 1) assert.Equal(t, int64(2), collection2Stats.NumDocWrites.Value()) } } diff --git a/rest/api_test.go b/rest/api_test.go index a2c505cca7..b427bbaf64 100644 --- a/rest/api_test.go +++ b/rest/api_test.go @@ -74,42 +74,37 @@ func TestPublicRESTStatCount(t *testing.T) { // create a user to authenticate as for public api calls and assert the stat hasn't incremented as a result rt.CreateUser("greg", []string{"ABC"}) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 0) - require.True(t, ok) // use public api to put a doc through SGW then assert the stat has increased by 1 resp := rt.SendUserRequest(http.MethodPut, "/{{.keyspace}}/doc1", `{"foo":"bar", "channels":["ABC"]}`, "greg") RequireStatus(t, resp, http.StatusCreated) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 1) - require.True(t, ok) // send admin request assert that the public rest count doesn't increase resp = rt.SendAdminRequest(http.MethodGet, "/{{.keyspace}}/doc1", "") RequireStatus(t, resp, http.StatusOK) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 1) - require.True(t, ok) // send another public request to assert the stat increases by 1 resp = rt.SendUserRequest(http.MethodGet, "/{{.keyspace}}/doc1", "", "greg") RequireStatus(t, resp, http.StatusOK) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 2) - require.True(t, ok) resp = rt.SendUserRequest(http.MethodGet, "/{{.db}}/_blipsync", "", "greg") RequireStatus(t, resp, http.StatusUpgradeRequired) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 2) - require.True(t, ok) srv := httptest.NewServer(rt.TestMetricsHandler()) defer srv.Close() @@ -120,19 +115,17 @@ func TestPublicRESTStatCount(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, response.StatusCode) // assert the stat doesn't increment - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 2) - require.True(t, ok) // test public endpoint but one that doesn't access a db resp = rt.SendUserRequest(http.MethodGet, "/", "", "greg") RequireStatus(t, resp, http.StatusOK) // assert the stat doesn't increment - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.NumPublicRestRequests.Value() }, 2) - require.True(t, ok) } func TestDBRoot(t *testing.T) { @@ -1444,7 +1437,7 @@ func TestEventConfigValidationInvalid(t *testing.T) { buf := bytes.NewBufferString(dbConfigJSON) var dbConfig DbConfig - err := DecodeAndSanitiseConfig(buf, &dbConfig, true) + err := DecodeAndSanitiseConfig(base.TestCtx(t), buf, &dbConfig, true) require.Error(t, err) assert.Contains(t, err.Error(), "document_scribbled_on") } @@ -2746,7 +2739,7 @@ func TestNullDocHandlingForMutable1xBody(t *testing.T) { documentRev := db.DocumentRevision{DocID: "doc1", BodyBytes: []byte("null")} - body, err := documentRev.Mutable1xBody(collection, nil, nil, false) + body, err := documentRev.Mutable1xBody(rt.Context(), collection, nil, nil, false) require.Error(t, err) require.Nil(t, body) assert.Contains(t, err.Error(), "null doc body for doc") diff --git a/rest/api_test_no_race_test.go b/rest/api_test_no_race_test.go index b2ae748f02..8612552022 100644 --- a/rest/api_test_no_race_test.go +++ b/rest/api_test_no_race_test.go @@ -196,7 +196,7 @@ func TestSetupAndValidate(t *testing.T) { }`)) defer deleteTempFile(t, configFile) args := []string{"sync_gateway", configFile.Name()} - config, err := setupServerConfig(args) + config, err := setupServerConfig(base.TestCtx(t), args) require.NoError(t, err, "Error reading config file") require.NotNil(t, config) @@ -243,7 +243,7 @@ func TestSetupAndValidate(t *testing.T) { configFile := createTempFile(t, []byte(`{"unknownKey":"unknownValue"}`)) defer deleteTempFile(t, configFile) args := []string{"sync_gateway", configFile.Name()} - config, err := setupServerConfig(args) + config, err := setupServerConfig(base.TestCtx(t), args) require.Error(t, err, "Should throw error reading file") assert.Contains(t, err.Error(), "unrecognized JSON field") assert.Nil(t, config) @@ -253,7 +253,7 @@ func TestSetupAndValidate(t *testing.T) { configFile := createTempFile(t, []byte(``)) args := []string{"sync_gateway", configFile.Name()} deleteTempFile(t, configFile) - config, err := setupServerConfig(args) + config, err := setupServerConfig(base.TestCtx(t), args) require.Error(t, err, "Should throw error reading file") assert.Contains(t, err.Error(), "Error reading config file") assert.Nil(t, config) @@ -276,7 +276,7 @@ func TestSetupAndValidate(t *testing.T) { }`)) defer deleteTempFile(t, configFile) args := []string{"sync_gateway", configFile.Name()} - config, err := setupServerConfig(args) + config, err := setupServerConfig(base.TestCtx(t), args) require.Error(t, err, "Should throw error reading file") assert.Contains(t, err.Error(), "minimum value for unsupported.stats_log_freq_secs is: 10") assert.Nil(t, config) diff --git a/rest/attachment_test.go b/rest/attachment_test.go index 4ab478fa92..210c9d78a6 100644 --- a/rest/attachment_test.go +++ b/rest/attachment_test.go @@ -2441,7 +2441,7 @@ func TestAttachmentRemovalWithConflicts(t *testing.T) { losingRev3 := RespRevID(t, resp) // Create doc conflicting with previous revid referencing previous attachment too - _, revIDHash := db.ParseRevID(revid) + _, revIDHash := db.ParseRevID(rt.Context(), revid) resp = rt.SendAdminRequest("PUT", "/{{.keyspace}}/doc?new_edits=false", `{"_rev": "3-b", "_revisions": {"ids": ["b", "`+revIDHash+`"], "start": 3}, "_attachments": {"hello.txt": {"revpos":2,"stub":true,"digest":"sha1-Kq5sNclPz7QV2+lfQIuc6R7oRu0="}}, "Winning Rev": true}`) RequireStatus(t, resp, http.StatusCreated) winningRev3 := RespRevID(t, resp) diff --git a/rest/attachmentcompactiontest/attachment_compaction_api_test.go b/rest/attachmentcompactiontest/attachment_compaction_api_test.go index 39efe0c24e..d1f165fd70 100644 --- a/rest/attachmentcompactiontest/attachment_compaction_api_test.go +++ b/rest/attachmentcompactiontest/attachment_compaction_api_test.go @@ -454,11 +454,11 @@ func TestAttachmentCompactionMarkPhaseRollback(t *testing.T) { name := db.GenerateCompactionDCPStreamName(stat.CompactID, "mark") checkpointPrefix := fmt.Sprintf("%s:%v", "_sync:dcp_ck:", name) - meta := base.NewDCPMetadataCS(dataStore, 1024, 8, checkpointPrefix) + meta := base.NewDCPMetadataCS(rt.Context(), dataStore, 1024, 8, checkpointPrefix) vbMeta := meta.GetMeta(0) vbMeta.VbUUID = garbageVBUUID meta.SetMeta(0, vbMeta) - meta.Persist(0, []uint16{0}) + meta.Persist(rt.Context(), 0, []uint16{0}) // kick off a new run attempting to start it again (should force into rollback handling) resp = rt.SendAdminRequest("POST", "/{{.db}}/_compact?type=attachment&action=start", "") diff --git a/rest/attachmentcompactiontest/main_test.go b/rest/attachmentcompactiontest/main_test.go index 13663c4b76..e7beaada5a 100644 --- a/rest/attachmentcompactiontest/main_test.go +++ b/rest/attachmentcompactiontest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package attachmentcompactiontest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/blip_api_crud_test.go b/rest/blip_api_crud_test.go index 2742ebdefc..e2365f2703 100644 --- a/rest/blip_api_crud_test.go +++ b/rest/blip_api_crud_test.go @@ -2454,8 +2454,7 @@ func TestProcessRevIncrementsStat(t *testing.T) { err = activeRT.WaitForRev("doc", rev) require.NoError(t, err) - _, ok := base.WaitForStat(pullStats.HandleRevCount.Value, 1) - require.True(t, ok) + base.RequireWaitForStat(t, pullStats.HandleRevCount.Value, 1) assert.NotEqualValues(t, 0, pullStats.HandleRevBytes.Value()) // Confirm connected client count has not increased, which uses same processRev code assert.EqualValues(t, 0, pullStats.HandlePutRevCount.Value()) @@ -2631,8 +2630,7 @@ func TestUnsubChanges(t *testing.T) { assert.NoError(t, err) assert.Empty(t, response) // Wait for unsub changes to stop the sub changes being sent before sending document up - activeReplVal, _ := base.WaitForStat(activeReplStat.Value, 0) - assert.EqualValues(t, 0, activeReplVal) + base.RequireWaitForStat(t, activeReplStat.Value, 0) // Confirm no more changes are being sent resp = rt.UpdateDoc("doc2", "", `{"key":"val1"}`) diff --git a/rest/blip_client_test.go b/rest/blip_client_test.go index cf4e2b554c..edc8474af5 100644 --- a/rest/blip_client_test.go +++ b/rest/blip_client_test.go @@ -12,7 +12,6 @@ package rest import ( "bytes" - "context" "encoding/base64" "fmt" "net/http" @@ -97,6 +96,7 @@ func (btr *BlipTesterReplicator) initHandlers(btc *BlipTesterClient) { btr.replicationStats = db.NewBlipSyncStats() } + ctx := base.DatabaseLogCtx(base.TestCtx(btr.bt.restTester.TB), btr.bt.restTester.GetDatabase().Name, nil) btr.bt.blipContext.HandlerForProfile[db.MessageProveAttachment] = func(msg *blip.Message) { btr.storeMessage(msg) @@ -121,7 +121,7 @@ func (btr *BlipTesterReplicator) initHandlers(btc *BlipTesterClient) { panic(fmt.Sprintf("error getting client attachment: %v", err)) } - proof := db.ProveAttachment(attData, nonce) + proof := db.ProveAttachment(ctx, attData, nonce) resp := msg.Response() resp.SetBody([]byte(proof)) @@ -335,7 +335,7 @@ func (btr *BlipTesterReplicator) initHandlers(btc *BlipTesterClient) { if err != nil { panic(err) } - nonce, proof, err := db.GenerateProofOfAttachment(attData) + nonce, proof, err := db.GenerateProofOfAttachment(ctx, attData) if err != nil { panic(err) } @@ -446,14 +446,14 @@ func (btr *BlipTesterReplicator) initHandlers(btc *BlipTesterClient) { digest, ok := msg.Properties[db.GetAttachmentDigest] if !ok { - base.PanicfCtx(context.TODO(), "couldn't find digest in getAttachment message properties") + base.PanicfCtx(ctx, "couldn't find digest in getAttachment message properties") } btcr := btc.getCollectionClientFromMessage(msg) attachment, err := btcr.getAttachment(digest) if err != nil { - base.PanicfCtx(context.TODO(), "couldn't find attachment for digest: %v", digest) + base.PanicfCtx(ctx, "couldn't find attachment for digest: %v", digest) } response := msg.Response() @@ -468,7 +468,7 @@ func (btr *BlipTesterReplicator) initHandlers(btc *BlipTesterClient) { btr.bt.blipContext.DefaultHandler = func(msg *blip.Message) { btr.storeMessage(msg) - base.PanicfCtx(context.TODO(), "Unknown profile: %s caught by client DefaultHandler - msg: %#v", msg.Profile(), msg) + base.PanicfCtx(ctx, "Unknown profile: %s caught by client DefaultHandler - msg: %#v", msg.Profile(), msg) } } @@ -477,6 +477,8 @@ func (btc *BlipTesterCollectionClient) saveAttachment(_, base64data string) (dat btc.attachmentsLock.Lock() defer btc.attachmentsLock.Unlock() + ctx := base.DatabaseLogCtx(base.TestCtx(btc.parent.rt.TB), btc.parent.rt.GetDatabase().Name, nil) + data, err := base64.StdEncoding.DecodeString(base64data) if err != nil { return 0, "", err @@ -484,7 +486,7 @@ func (btc *BlipTesterCollectionClient) saveAttachment(_, base64data string) (dat digest = db.Sha1DigestKey(data) if _, found := btc.attachments[digest]; found { - base.InfofCtx(context.TODO(), base.KeySync, "attachment with digest %s already exists", digest) + base.InfofCtx(ctx, base.KeySync, "attachment with digest %s already exists", digest) } else { btc.attachments[digest] = data } @@ -514,8 +516,9 @@ func (btc *BlipTesterCollectionClient) updateLastReplicatedRev(docID, revID stri return } - currentGen, _ := db.ParseRevID(currentRevID) - incomingGen, _ := db.ParseRevID(revID) + ctx := base.TestCtx(btc.parent.rt.TB) + currentGen, _ := db.ParseRevID(ctx, currentRevID) + incomingGen, _ := db.ParseRevID(ctx, revID) if incomingGen > currentGen { btc.lastReplicatedRev[docID] = revID } @@ -787,7 +790,8 @@ func (btc *BlipTesterCollectionClient) PushRev(docID, parentRev string, body []b // PushRevWithHistory creates a revision on the client with history, and immediately sends a changes request for it. func (btc *BlipTesterCollectionClient) PushRevWithHistory(docID, parentRev string, body []byte, revCount, prunedRevCount int) (revID string, err error) { - parentRevGen, _ := db.ParseRevID(parentRev) + ctx := base.DatabaseLogCtx(base.TestCtx(btc.parent.rt.TB), btc.parent.rt.GetDatabase().Name, nil) + parentRevGen, _ := db.ParseRevID(ctx, parentRev) revGen := parentRevGen + revCount + prunedRevCount var revisionHistory []string @@ -864,7 +868,7 @@ func (btc *BlipTesterCollectionClient) PushRevWithHistory(docID, parentRev strin btc.addCollectionProperty(revRequest) if btc.parent.ClientDeltas && proposeChangesResponse.Properties[db.ProposeChangesResponseDeltas] == "true" { - base.DebugfCtx(context.TODO(), base.KeySync, "Sending deltas from test client") + base.DebugfCtx(ctx, base.KeySync, "Sending deltas from test client") var parentDocJSON, newDocJSON db.Body err := parentDocJSON.Unmarshal(parentDocBody) if err != nil { @@ -883,7 +887,7 @@ func (btc *BlipTesterCollectionClient) PushRevWithHistory(docID, parentRev strin revRequest.Properties[db.RevMessageDeltaSrc] = parentRev body = delta } else { - base.DebugfCtx(context.TODO(), base.KeySync, "Not sending deltas from test client") + base.DebugfCtx(ctx, base.KeySync, "Not sending deltas from test client") } revRequest.SetBody(body) @@ -907,7 +911,8 @@ func (btc *BlipTesterCollectionClient) PushRevWithHistory(docID, parentRev strin } func (btc *BlipTesterCollectionClient) StoreRevOnClient(docID, revID string, body []byte) error { - revGen, _ := db.ParseRevID(revID) + ctx := base.DatabaseLogCtx(base.TestCtx(btc.parent.rt.TB), btc.parent.rt.GetDatabase().Name, nil) + revGen, _ := db.ParseRevID(ctx, revID) newBody, err := btc.ProcessInlineAttachments(body, revGen) if err != nil { return err diff --git a/rest/blip_stats_test.go b/rest/blip_stats_test.go index de640ff761..10a87d8438 100644 --- a/rest/blip_stats_test.go +++ b/rest/blip_stats_test.go @@ -34,7 +34,7 @@ func waitForStatGreaterThan(t *testing.T, getStatFunc func() int64, expected int return stat <= expected, nil, val } // wait for up to 20 seconds for the stat to meet the expected value - err, val := base.RetryLoop("waitForStatGreaterThan retry loop", workerFunc, base.CreateSleeperFunc(200, 100)) + err, val := base.RetryLoop(base.TestCtx(t), "waitForStatGreaterThan retry loop", workerFunc, base.CreateSleeperFunc(200, 100)) require.NoError(t, err) valInt64, ok := val.(int64) require.True(t, ok) diff --git a/rest/blip_sync.go b/rest/blip_sync.go index a698b4bb1e..e6bfa2abc5 100644 --- a/rest/blip_sync.go +++ b/rest/blip_sync.go @@ -54,7 +54,7 @@ func (h *handler) handleBLIPSync() error { } // Overwrite the existing logging context with the blip context ID - h.rqCtx = base.LogContextWith(h.ctx(), &base.LogContext{CorrelationID: base.FormatBlipContextID(blipContext.ID)}) + h.rqCtx = base.CorrelationIDLogCtx(h.ctx(), base.FormatBlipContextID(blipContext.ID)) // Create a new BlipSyncContext attached to the given blipContext. ctx := db.NewBlipSyncContext(h.rqCtx, blipContext, h.db, h.formatSerialNumber(), db.BlipSyncStatsForCBL(h.db.DbStats)) diff --git a/rest/bootstrap_test.go b/rest/bootstrap_test.go index dfd8251e06..efd33400a1 100644 --- a/rest/bootstrap_test.go +++ b/rest/bootstrap_test.go @@ -43,7 +43,7 @@ func TestBootstrapRESTAPISetup(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -105,7 +105,7 @@ func TestBootstrapRESTAPISetup(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) defer func() { sc.Close(ctx) require.NoError(t, <-serverErr) @@ -162,7 +162,7 @@ func TestBootstrapDuplicateCollections(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -209,7 +209,7 @@ func TestBootstrapDuplicateDatabase(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -285,7 +285,7 @@ func DevTestFetchConfigManual(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Sleep to wait for bucket polling iterations, or allow manual modification to server accessibility diff --git a/rest/bulk_api.go b/rest/bulk_api.go index a3150280bd..8ee7d470ca 100644 --- a/rest/bulk_api.go +++ b/rest/bulk_api.go @@ -199,7 +199,7 @@ func (h *handler) handleAllDocs() error { options.Limit = h.getIntQuery("limit", 0) // Now it's time to actually write the response! - lastSeq, _ := h.db.LastSequence() + lastSeq, _ := h.db.LastSequence(h.ctx()) h.setHeader("Content-Type", "application/json") // response.Write below would set Status OK implicitly. We manually do it here to ensure that our handler knows // that the header has been written to, meaning we can prevent it from attempting to set the header again later on. @@ -315,7 +315,7 @@ func (h *handler) handleDumpChannel() error { since := h.getIntQuery("since", 0) base.InfofCtx(h.ctx(), base.KeyHTTP, "Dump channel %q", base.UD(channelName)) - chanLog, _ := h.collection.GetChangeLog(ch.NewID(channelName, h.collection.GetCollectionID()), since) + chanLog, _ := h.collection.GetChangeLog(h.ctx(), ch.NewID(channelName, h.collection.GetCollectionID()), since) if chanLog == nil { return base.HTTPErrorf(http.StatusNotFound, "no such channel") } @@ -510,7 +510,7 @@ func (h *handler) handleBulkDocs() error { docid, revid, _, err = h.collection.Post(h.ctx(), doc) } } else { - revisions := db.ParseRevisions(doc) + revisions := db.ParseRevisions(h.ctx(), doc) if revisions == nil { err = base.HTTPErrorf(http.StatusBadRequest, "Bad _revisions") } else { diff --git a/rest/bytes_read_public_api_test.go b/rest/bytes_read_public_api_test.go index ec67557e5f..67cdb4c6c1 100644 --- a/rest/bytes_read_public_api_test.go +++ b/rest/bytes_read_public_api_test.go @@ -29,10 +29,9 @@ func TestBytesReadDocOperations(t *testing.T) { // create a user to authenticate as for public api calls and assert the stat hasn't incremented as a result rt.CreateUser("greg", []string{"ABC"}) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, 0) - require.True(t, ok) // use public api to put a doc through SGW then assert the stat has increased input := `{"foo":"bar", "channels":["ABC"]}` @@ -40,34 +39,30 @@ func TestBytesReadDocOperations(t *testing.T) { resp := rt.SendUserRequest(http.MethodPut, "/{{.keyspace}}/doc1", input, "greg") RequireStatus(t, resp, http.StatusCreated) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) // send admin request assert that the public rest count doesn't increase resp = rt.SendAdminRequest(http.MethodGet, "/{{.keyspace}}/doc1", "") RequireStatus(t, resp, http.StatusOK) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) // send user request that has empty body, assert the stat doesn't increase resp = rt.SendUserRequest(http.MethodGet, "/{{.keyspace}}/doc1", "", "greg") RequireStatus(t, resp, http.StatusOK) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) // assert blipsync connection doesn't increment stat resp = rt.SendUserRequest(http.MethodGet, "/{{.db}}/_blipsync", "", "greg") RequireStatus(t, resp, http.StatusUpgradeRequired) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) srv := httptest.NewServer(rt.TestMetricsHandler()) defer srv.Close() @@ -78,19 +73,17 @@ func TestBytesReadDocOperations(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, response.StatusCode) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) // test public endpoint but one that doesn't access a db and assert that doesn't increment stat resp = rt.SendUserRequest(http.MethodGet, "/", "", "greg") RequireStatus(t, resp, http.StatusOK) // assert the stat doesn't increment - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) // send another public request (this time POST) to check stat increments, but check it increments by correct bytes value input = fmt.Sprint(`{"foo":"bar", "channels":["ABC"]}`) @@ -100,10 +93,9 @@ func TestBytesReadDocOperations(t *testing.T) { RequireStatus(t, resp, http.StatusOK) cumulativeBytes := len(inputBytes) + len(inputBytes2) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(cumulativeBytes)) - require.True(t, ok) } func TestBytesReadChanges(t *testing.T) { @@ -112,10 +104,9 @@ func TestBytesReadChanges(t *testing.T) { // create a user and assert this doesn't increase the bytes read stat rt.CreateUser("alice", nil) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, 0) - require.True(t, ok) // take the bytes of the body we will pass into request and perform changes POST request changesJSON := `{"style":"all_docs", "timeout":6000, "feed":"longpoll", "limit":50, "since":"0"}` @@ -124,10 +115,9 @@ func TestBytesReadChanges(t *testing.T) { RequireStatus(t, resp, http.StatusOK) // assert the stat has increased by the number of bytes passed into request - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(byteArrayChangesBody))) - require.True(t, ok) } @@ -161,10 +151,9 @@ func TestBytesReadPutAttachment(t *testing.T) { RequireStatus(t, resp, http.StatusCreated) // assert the stat has increased by the attachment endpoint input - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(byteArrayAttachmentBody))) - require.True(t, ok) // test incorrect user still increments count resp = rt.SendUserRequestWithHeaders("PUT", "/{{.keyspace}}/doc1/attach1?rev="+revid, attachmentBody, reqHeaders, "bob", "letmein") @@ -172,10 +161,9 @@ func TestBytesReadPutAttachment(t *testing.T) { newStatNum := len(byteArrayAttachmentBody) * 2 - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStatNum)) - require.True(t, ok) } @@ -206,20 +194,18 @@ func TestBytesReadRevDiff(t *testing.T) { RequireStatus(t, resp, http.StatusOK) // assert the stat has increased by the bytes above - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest(http.MethodPost, "/{{.keyspace}}/_revs_diff", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } @@ -248,20 +234,18 @@ func TestBytesReadAllDocs(t *testing.T) { RequireStatus(t, resp, http.StatusOK) // assert the stat has increased by the bytes length - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest(http.MethodPost, "/{{.keyspace}}/_all_docs", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } @@ -277,20 +261,18 @@ func TestBytesReadBulkDocs(t *testing.T) { resp := rt.SendUserRequest("POST", "/{{.keyspace}}/_bulk_docs", input, "alice") RequireStatus(t, resp, http.StatusCreated) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest("POST", "/{{.keyspace}}/_bulk_docs", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } @@ -314,20 +296,18 @@ func TestBytesReadBulkGet(t *testing.T) { RequireStatus(t, resp, http.StatusOK) // assert the stat has increased by the length of byte array - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest(http.MethodPost, "/{{.keyspace}}/_bulk_get", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } @@ -344,20 +324,18 @@ func TestBytesReadLocalDocPut(t *testing.T) { RequireStatus(t, resp, http.StatusCreated) // assert the stat is increased by the correct amount - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest(http.MethodPut, "/{{.keyspace}}/_local/doc1", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } func TestBytesReadPOSTSession(t *testing.T) { @@ -373,20 +351,18 @@ func TestBytesReadPOSTSession(t *testing.T) { RequireStatus(t, resp, http.StatusOK) // assert the stat is increased by the correct amount - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) newStat := len(inputBytes) * 2 // now try failed auth resp = rt.SendUserRequest(http.MethodPost, "/{{.db}}/_session", input, "bob") RequireStatus(t, resp, http.StatusUnauthorized) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(newStat)) - require.True(t, ok) } func TestBytesReadAuthFailed(t *testing.T) { @@ -404,10 +380,9 @@ func TestBytesReadAuthFailed(t *testing.T) { RequireStatus(t, resp, http.StatusUnauthorized) // assert the stat has still increased by the bytes of the body passed into request - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) } @@ -437,11 +412,10 @@ func TestBytesReadGzipRequest(t *testing.T) { resp := rt.Send(rq) RequireStatus(t, resp, http.StatusCreated) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { fmt.Println(rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value()) return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) } @@ -471,10 +445,9 @@ func TestPutDBBytesRead(t *testing.T) { RequireStatus(t, resp, http.StatusCreated) // assert the stat hasn't increased (admin request doesn't effect count) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, 0) - require.True(t, ok) } @@ -490,10 +463,9 @@ func TestOfflineDBBytesRead(t *testing.T) { resp = rt.SendUserRequest(http.MethodGet, "/{{.db}}/", "", "alice") RequireStatus(t, resp, http.StatusOK) - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, 0) - require.True(t, ok) // try adding body to get request input := `{"random": "body"}` @@ -501,9 +473,8 @@ func TestOfflineDBBytesRead(t *testing.T) { resp = rt.SendUserRequest(http.MethodGet, "/{{.db}}/", input, "alice") RequireStatus(t, resp, http.StatusOK) - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() }, int64(len(inputBytes))) - require.True(t, ok) } diff --git a/rest/changes_api.go b/rest/changes_api.go index 58c2b990b8..c85b38a97a 100644 --- a/rest/changes_api.go +++ b/rest/changes_api.go @@ -319,7 +319,7 @@ func (h *handler) handleChanges() error { h.db.DatabaseContext.DbStats.Database().NumReplicationsTotal.Add(1) defer h.db.DatabaseContext.DbStats.Database().NumReplicationsActive.Add(-1) - changesCtx, changesCtxCancel := context.WithCancel(context.Background()) + changesCtx, changesCtxCancel := context.WithCancel(h.ctx()) options.ChangesCtx = changesCtx forceClose := false @@ -353,7 +353,7 @@ func (h *handler) handleChanges() error { if h.user != nil { user = h.user.Name() } - h.db.DatabaseContext.NotifyTerminatedChanges(user) + h.db.DatabaseContext.NotifyTerminatedChanges(h.ctx(), user) } return err @@ -511,7 +511,7 @@ func (h *handler) sendContinuousChangesByWebSocket(inChannels base.Set, options // Read changes-feed options from an initial incoming WebSocket message in JSON format: var wsoptions db.ChangesOptions var compress bool - if msg, err := readWebSocketMessage(conn); err != nil { + if msg, err := readWebSocketMessage(h.ctx(), conn); err != nil { return } else { var channelNames []string @@ -650,12 +650,12 @@ func (h *handler) readChangesOptionsFromJSON(jsonData []byte) (feed string, opti } // Helper function to read a complete message from a WebSocket -func readWebSocketMessage(conn *websocket.Conn) ([]byte, error) { +func readWebSocketMessage(ctx context.Context, conn *websocket.Conn) ([]byte, error) { var message []byte if err := websocket.Message.Receive(conn, &message); err != nil { if err != io.EOF { - base.WarnfCtx(context.TODO(), "Error reading initial websocket message: %v", err) + base.WarnfCtx(ctx, "Error reading initial websocket message: %v", err) return nil, err } } diff --git a/rest/changes_test.go b/rest/changes_test.go index c448525bff..39d96b1723 100644 --- a/rest/changes_test.go +++ b/rest/changes_test.go @@ -236,7 +236,7 @@ func TestWebhookWinningRevChangedEvent(t *testing.T) { res := rt.SendAdminRequest("PUT", "/{{.keyspace}}/doc1", `{"foo":"bar"}`) RequireStatus(t, res, http.StatusCreated) rev1 := RespRevID(t, res) - _, rev1Hash := db.ParseRevID(rev1) + _, rev1Hash := db.ParseRevID(rt.Context(), rev1) // push winning branch wg.Add(2) diff --git a/rest/changestest/changes_api_test.go b/rest/changestest/changes_api_test.go index 1154cc6def..a14815180a 100644 --- a/rest/changestest/changes_api_test.go +++ b/rest/changestest/changes_api_test.go @@ -9,7 +9,6 @@ package changestest import ( - "context" "encoding/json" "errors" "fmt" @@ -1002,9 +1001,9 @@ func TestChangesLoopingWhenLowSequence(t *testing.T) { rest.RequireStatus(t, response, 201) // Simulate seq 3 and 4 being delayed - write 1,2,5,6 - WriteDirect([]string{"PBS"}, 2, collection) - WriteDirect([]string{"PBS"}, 5, collection) - WriteDirect([]string{"PBS"}, 6, collection) + WriteDirect(t, []string{"PBS"}, 2, collection) + WriteDirect(t, []string{"PBS"}, 5, collection) + WriteDirect(t, []string{"PBS"}, 6, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 6)) // Check the _changes feed: @@ -1029,7 +1028,7 @@ func TestChangesLoopingWhenLowSequence(t *testing.T) { require.Len(t, changes.Results, 0) // Send a missing doc - low sequence should move to 3 - WriteDirect([]string{"PBS"}, 3, collection) + WriteDirect(t, []string{"PBS"}, 3, collection) require.NoError(t, rt.WaitForSequence(3)) // WaitForSequence doesn't wait for low sequence to be updated on each channel - additional delay to ensure @@ -1043,7 +1042,7 @@ func TestChangesLoopingWhenLowSequence(t *testing.T) { require.Len(t, changes.Results, 3) // Send a later doc - low sequence still 3, high sequence goes to 7 - WriteDirect([]string{"PBS"}, 7, collection) + WriteDirect(t, []string{"PBS"}, 7, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 7)) // Send another changes request with the same since ("2::6") to ensure we see data once there are changes @@ -1094,10 +1093,10 @@ func TestChangesLoopingWhenLowSequenceOneShotUser(t *testing.T) { rest.RequireStatus(t, response, 201) // Simulate 4 non-skipped writes (seq 2,3,4,5) - WriteDirect([]string{"PBS"}, 2, collection) - WriteDirect([]string{"PBS"}, 3, collection) - WriteDirect([]string{"PBS"}, 4, collection) - WriteDirect([]string{"PBS"}, 5, collection) + WriteDirect(t, []string{"PBS"}, 2, collection) + WriteDirect(t, []string{"PBS"}, 3, collection) + WriteDirect(t, []string{"PBS"}, 4, collection) + WriteDirect(t, []string{"PBS"}, 5, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 5)) // Check the _changes feed: @@ -1114,10 +1113,10 @@ func TestChangesLoopingWhenLowSequenceOneShotUser(t *testing.T) { assert.Equal(t, "5", changes.Last_Seq) // Skip sequence 6, write docs 7-10 - WriteDirect([]string{"PBS"}, 7, collection) - WriteDirect([]string{"PBS"}, 8, collection) - WriteDirect([]string{"PBS"}, 9, collection) - WriteDirect([]string{"PBS"}, 10, collection) + WriteDirect(t, []string{"PBS"}, 7, collection) + WriteDirect(t, []string{"PBS"}, 8, collection) + WriteDirect(t, []string{"PBS"}, 9, collection) + WriteDirect(t, []string{"PBS"}, 10, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 10)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1130,8 +1129,8 @@ func TestChangesLoopingWhenLowSequenceOneShotUser(t *testing.T) { assert.Equal(t, "5::10", changes.Last_Seq) // Write a few more docs - WriteDirect([]string{"PBS"}, 11, collection) - WriteDirect([]string{"PBS"}, 12, collection) + WriteDirect(t, []string{"PBS"}, 11, collection) + WriteDirect(t, []string{"PBS"}, 12, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 12)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1144,8 +1143,8 @@ func TestChangesLoopingWhenLowSequenceOneShotUser(t *testing.T) { assert.Equal(t, "5::12", changes.Last_Seq) // Write another doc, then the skipped doc - both should be sent, last_seq should move to 13 - WriteDirect([]string{"PBS"}, 13, collection) - WriteDirect([]string{"PBS"}, 6, collection) + WriteDirect(t, []string{"PBS"}, 13, collection) + WriteDirect(t, []string{"PBS"}, 6, collection) require.NoError(t, rt.WaitForSequence(13)) changesJSON = fmt.Sprintf(`{"since":"%s"}`, changes.Last_Seq) @@ -1224,11 +1223,11 @@ func TestChangesLoopingWhenLowSequenceOneShotAdmin(t *testing.T) { ctx := rt.Context() // Simulate 5 non-skipped writes (seq 1,2,3,4,5) - WriteDirect([]string{"PBS"}, 1, collection) - WriteDirect([]string{"PBS"}, 2, collection) - WriteDirect([]string{"PBS"}, 3, collection) - WriteDirect([]string{"PBS"}, 4, collection) - WriteDirect([]string{"PBS"}, 5, collection) + WriteDirect(t, []string{"PBS"}, 1, collection) + WriteDirect(t, []string{"PBS"}, 2, collection) + WriteDirect(t, []string{"PBS"}, 3, collection) + WriteDirect(t, []string{"PBS"}, 4, collection) + WriteDirect(t, []string{"PBS"}, 5, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 5)) // Check the _changes feed: var changes struct { @@ -1244,10 +1243,10 @@ func TestChangesLoopingWhenLowSequenceOneShotAdmin(t *testing.T) { assert.Equal(t, "5", changes.Last_Seq) // Skip sequence 6, write docs 7-10 - WriteDirect([]string{"PBS"}, 7, collection) - WriteDirect([]string{"PBS"}, 8, collection) - WriteDirect([]string{"PBS"}, 9, collection) - WriteDirect([]string{"PBS"}, 10, collection) + WriteDirect(t, []string{"PBS"}, 7, collection) + WriteDirect(t, []string{"PBS"}, 8, collection) + WriteDirect(t, []string{"PBS"}, 9, collection) + WriteDirect(t, []string{"PBS"}, 10, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 10)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1260,8 +1259,8 @@ func TestChangesLoopingWhenLowSequenceOneShotAdmin(t *testing.T) { assert.Equal(t, "5::10", changes.Last_Seq) // Write a few more docs - WriteDirect([]string{"PBS"}, 11, collection) - WriteDirect([]string{"PBS"}, 12, collection) + WriteDirect(t, []string{"PBS"}, 11, collection) + WriteDirect(t, []string{"PBS"}, 12, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 12)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1274,8 +1273,8 @@ func TestChangesLoopingWhenLowSequenceOneShotAdmin(t *testing.T) { assert.Equal(t, "5::12", changes.Last_Seq) // Write another doc, then the skipped doc - both should be sent, last_seq should move to 13 - WriteDirect([]string{"PBS"}, 13, collection) - WriteDirect([]string{"PBS"}, 6, collection) + WriteDirect(t, []string{"PBS"}, 13, collection) + WriteDirect(t, []string{"PBS"}, 6, collection) require.NoError(t, rt.WaitForSequence(13)) changesJSON = fmt.Sprintf(`{"since":"%s"}`, changes.Last_Seq) @@ -1360,10 +1359,10 @@ func TestChangesLoopingWhenLowSequenceLongpollUser(t *testing.T) { rest.RequireStatus(t, response, 201) // Simulate 4 non-skipped writes (seq 2,3,4,5) - WriteDirect([]string{"PBS"}, 2, collection) - WriteDirect([]string{"PBS"}, 3, collection) - WriteDirect([]string{"PBS"}, 4, collection) - WriteDirect([]string{"PBS"}, 5, collection) + WriteDirect(t, []string{"PBS"}, 2, collection) + WriteDirect(t, []string{"PBS"}, 3, collection) + WriteDirect(t, []string{"PBS"}, 4, collection) + WriteDirect(t, []string{"PBS"}, 5, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 5)) // Check the _changes feed: @@ -1380,10 +1379,10 @@ func TestChangesLoopingWhenLowSequenceLongpollUser(t *testing.T) { assert.Equal(t, "5", changes.Last_Seq) // Skip sequence 6, write docs 7-10 - WriteDirect([]string{"PBS"}, 7, collection) - WriteDirect([]string{"PBS"}, 8, collection) - WriteDirect([]string{"PBS"}, 9, collection) - WriteDirect([]string{"PBS"}, 10, collection) + WriteDirect(t, []string{"PBS"}, 7, collection) + WriteDirect(t, []string{"PBS"}, 8, collection) + WriteDirect(t, []string{"PBS"}, 9, collection) + WriteDirect(t, []string{"PBS"}, 10, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 10)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1396,8 +1395,8 @@ func TestChangesLoopingWhenLowSequenceLongpollUser(t *testing.T) { assert.Equal(t, "5::10", changes.Last_Seq) // Write a few more docs - WriteDirect([]string{"PBS"}, 11, collection) - WriteDirect([]string{"PBS"}, 12, collection) + WriteDirect(t, []string{"PBS"}, 11, collection) + WriteDirect(t, []string{"PBS"}, 12, collection) require.NoError(t, collection.WaitForSequenceNotSkipped(ctx, 12)) // Send another changes request with the last_seq received from the last changes ("5") @@ -1429,7 +1428,7 @@ func TestChangesLoopingWhenLowSequenceLongpollUser(t *testing.T) { require.NoError(t, rt.GetDatabase().WaitForCaughtUp(caughtUpCount+1)) // Write the skipped doc, wait for longpoll to return - WriteDirect([]string{"PBS"}, 6, collection) + WriteDirect(t, []string{"PBS"}, 6, collection) // WriteDirect(testDb, []string{"PBS"}, 13) longpollWg.Wait() @@ -4286,12 +4285,12 @@ func waitForCompactStopped(dbc *db.DatabaseContext) error { // ////// HELPERS: -func WriteDirect(channelArray []string, sequence uint64, collection *db.DatabaseCollection) { +func WriteDirect(t *testing.T, channelArray []string, sequence uint64, collection *db.DatabaseCollection) { docId := fmt.Sprintf("doc-%v", sequence) - WriteDirectWithKey(docId, channelArray, sequence, collection) + WriteDirectWithKey(t, docId, channelArray, sequence, collection) } -func WriteDirectWithKey(key string, channelArray []string, sequence uint64, collection *db.DatabaseCollection) { +func WriteDirectWithKey(t *testing.T, key string, channelArray []string, sequence uint64, collection *db.DatabaseCollection) { if base.TestUseXattrs() { panic(fmt.Sprintf("WriteDirectWithKey() cannot be used in tests that are xattr enabled")) @@ -4317,8 +4316,6 @@ func WriteDirectWithKey(key string, channelArray []string, sequence uint64, coll dataStore := collection.GetCollectionDatastore() _, err := dataStore.Add(key, 0, db.Body{base.SyncPropertyName: syncData, "key": key}) - if err != nil { - base.PanicfCtx(context.TODO(), "Error while add ket to bucket: %v", err) - } + require.NoError(t, err) } diff --git a/rest/changestest/main_test.go b/rest/changestest/main_test.go index 84f952bfa2..51d8e2857e 100644 --- a/rest/changestest/main_test.go +++ b/rest/changestest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package changestest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/config.go b/rest/config.go index a4a0651ade..1c02f270a3 100644 --- a/rest/config.go +++ b/rest/config.go @@ -361,7 +361,7 @@ func (dbConfig *DbConfig) setDatabaseCredentials(credentials base.CredentialsCon } // setup populates fields in the dbConfig -func (dbConfig *DbConfig) setup(dbName string, bootstrapConfig BootstrapConfig, dbCredentials, bucketCredentials *base.CredentialsConfig, forcePerBucketAuth bool) error { +func (dbConfig *DbConfig) setup(ctx context.Context, dbName string, bootstrapConfig BootstrapConfig, dbCredentials, bucketCredentials *base.CredentialsConfig, forcePerBucketAuth bool) error { dbConfig.Name = dbName if dbConfig.Bucket == nil { dbConfig.Bucket = &dbConfig.Name @@ -406,7 +406,7 @@ func (dbConfig *DbConfig) setup(dbName string, bootstrapConfig BootstrapConfig, // Load Sync Function. if dbConfig.Sync != nil { - sync, err := loadJavaScript(*dbConfig.Sync, insecureSkipVerify) + sync, err := loadJavaScript(ctx, *dbConfig.Sync, insecureSkipVerify) if err != nil { return &JavaScriptLoadError{ JSLoadType: SyncFunction, @@ -419,7 +419,7 @@ func (dbConfig *DbConfig) setup(dbName string, bootstrapConfig BootstrapConfig, // Load Import Filter Function. if dbConfig.ImportFilter != nil { - importFilter, err := loadJavaScript(*dbConfig.ImportFilter, insecureSkipVerify) + importFilter, err := loadJavaScript(ctx, *dbConfig.ImportFilter, insecureSkipVerify) if err != nil { return &JavaScriptLoadError{ JSLoadType: ImportFilter, @@ -433,7 +433,7 @@ func (dbConfig *DbConfig) setup(dbName string, bootstrapConfig BootstrapConfig, // Load Conflict Resolution Function. for _, rc := range dbConfig.Replications { if rc.ConflictResolutionFn != "" { - conflictResolutionFn, err := loadJavaScript(rc.ConflictResolutionFn, insecureSkipVerify) + conflictResolutionFn, err := loadJavaScript(ctx, rc.ConflictResolutionFn, insecureSkipVerify) if err != nil { return &JavaScriptLoadError{ JSLoadType: ConflictResolver, @@ -452,8 +452,8 @@ func (dbConfig *DbConfig) setup(dbName string, bootstrapConfig BootstrapConfig, // If the specified path does not qualify for a valid file or an URI, it returns the input path // as-is with the assumption that it is an inline JavaScript source. Returns error if there is // any failure in reading the JavaScript file or URI. -func loadJavaScript(path string, insecureSkipVerify bool) (js string, err error) { - rc, err := readFromPath(path, insecureSkipVerify) +func loadJavaScript(ctx context.Context, path string, insecureSkipVerify bool) (js string, err error) { + rc, err := readFromPath(ctx, path, insecureSkipVerify) if errors.Is(err, ErrPathNotFound) { // If rc is nil and readFromPath returns no error, treat the // the given path as an inline JavaScript and return it as-is. @@ -520,10 +520,10 @@ var ErrPathNotFound = errors.New("path not found") // readFromPath creates a ReadCloser from the given path. The path must be either a valid file // or an HTTP/HTTPS endpoint. Returns an error if there is any failure in building ReadCloser. -func readFromPath(path string, insecureSkipVerify bool) (rc io.ReadCloser, err error) { +func readFromPath(ctx context.Context, path string, insecureSkipVerify bool) (rc io.ReadCloser, err error) { messageFormat := "Loading content from [%s] ..." if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - base.InfofCtx(context.Background(), base.KeyAll, messageFormat, path) + base.InfofCtx(ctx, base.KeyAll, messageFormat, path) client := base.GetHttpClient(insecureSkipVerify) resp, err := client.Get(path) if err != nil { @@ -534,7 +534,7 @@ func readFromPath(path string, insecureSkipVerify bool) (rc io.ReadCloser, err e } rc = resp.Body } else if base.FileExists(path) { - base.InfofCtx(context.Background(), base.KeyAll, messageFormat, path) + base.InfofCtx(ctx, base.KeyAll, messageFormat, path) rc, err = os.Open(path) if err != nil { return nil, err @@ -545,7 +545,7 @@ func readFromPath(path string, insecureSkipVerify bool) (rc io.ReadCloser, err e return rc, nil } -func (dbConfig *DbConfig) AutoImportEnabled() (bool, error) { +func (dbConfig *DbConfig) AutoImportEnabled(ctx context.Context) (bool, error) { if dbConfig.AutoImport == nil { if !dbConfig.UseXattrs() { return false, nil @@ -559,7 +559,7 @@ func (dbConfig *DbConfig) AutoImportEnabled() (bool, error) { str, ok := dbConfig.AutoImport.(string) if ok && str == "continuous" { - base.WarnfCtx(context.Background(), `Using deprecated config value for "import_docs": "continuous". Use "import_docs": true instead.`) + base.WarnfCtx(ctx, `Using deprecated config value for "import_docs": "continuous". Use "import_docs": true instead.`) return true, nil } @@ -729,7 +729,7 @@ func (dbConfig *DbConfig) validateVersion(ctx context.Context, isEnterpriseEditi } // Import validation - autoImportEnabled, err := dbConfig.AutoImportEnabled() + autoImportEnabled, err := dbConfig.AutoImportEnabled(ctx) if err != nil { multiError = multiError.Append(err) } @@ -1069,7 +1069,7 @@ func (dbConfig *DbConfig) UseXattrs() bool { return base.DefaultUseXattrs } -func (dbConfig *DbConfig) Redacted() (*DbConfig, error) { +func (dbConfig *DbConfig) Redacted(ctx context.Context) (*DbConfig, error) { var config DbConfig err := base.DeepCopyInefficient(&config, dbConfig) @@ -1077,12 +1077,12 @@ func (dbConfig *DbConfig) Redacted() (*DbConfig, error) { return nil, err } - err = config.redactInPlace() + err = config.redactInPlace(ctx) return &config, err } // redactInPlace modifies the given config to redact the fields inside it. -func (config *DbConfig) redactInPlace() error { +func (config *DbConfig) redactInPlace(ctx context.Context) error { if config.Password != "" { config.Password = base.RedactedStr @@ -1095,21 +1095,21 @@ func (config *DbConfig) redactInPlace() error { } for i, _ := range config.Replications { - config.Replications[i] = config.Replications[i].Redacted() + config.Replications[i] = config.Replications[i].Redacted(ctx) } return nil } // DecodeAndSanitiseConfig will sanitise a config from an io.Reader and unmarshal it into the given config parameter. -func DecodeAndSanitiseConfig(r io.Reader, config interface{}, disallowUnknownFields bool) (err error) { +func DecodeAndSanitiseConfig(ctx context.Context, r io.Reader, config interface{}, disallowUnknownFields bool) (err error) { b, err := io.ReadAll(r) if err != nil { return err } // Expand environment variables. - b, err = expandEnv(b) + b, err = expandEnv(ctx, b) if err != nil { return err } @@ -1127,14 +1127,14 @@ func DecodeAndSanitiseConfig(r io.Reader, config interface{}, disallowUnknownFie // current environment variables. The replacement is case-sensitive. References // to undefined variables will result in an error. A default value can // be given by using the form ${var:-default value}. -func expandEnv(config []byte) (value []byte, err error) { +func expandEnv(ctx context.Context, config []byte) (value []byte, err error) { var multiError *base.MultiError val := []byte(os.Expand(string(config), func(key string) string { if key == "$" { - base.DebugfCtx(context.Background(), base.KeyConfig, "Skipping environment variable expansion: %s", key) + base.DebugfCtx(ctx, base.KeyConfig, "Skipping environment variable expansion: %s", key) return key } - val, err := envDefaultExpansion(key, os.Getenv) + val, err := envDefaultExpansion(ctx, key, os.Getenv) if err != nil { multiError = multiError.Append(err) } @@ -1155,19 +1155,19 @@ func (e ErrEnvVarUndefined) Error() string { // envDefaultExpansion implements the ${foo:-bar} parameter expansion from // https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 -func envDefaultExpansion(key string, getEnvFn func(string) string) (value string, err error) { +func envDefaultExpansion(ctx context.Context, key string, getEnvFn func(string) string) (value string, err error) { kvPair := strings.SplitN(key, ":-", 2) key = kvPair[0] value = getEnvFn(key) if value == "" && len(kvPair) == 2 { // Set value to the default. value = kvPair[1] - base.DebugfCtx(context.Background(), base.KeyConfig, "Replacing config environment variable '${%s}' with "+ + base.DebugfCtx(ctx, base.KeyConfig, "Replacing config environment variable '${%s}' with "+ "default value specified", key) } else if value == "" && len(kvPair) != 2 { return "", ErrEnvVarUndefined{key: key} } else { - base.DebugfCtx(context.Background(), base.KeyConfig, "Replacing config environment variable '${%s}'", key) + base.DebugfCtx(ctx, base.KeyConfig, "Replacing config environment variable '${%s}'", key) } return value, nil } @@ -1193,20 +1193,20 @@ func (sc *StartupConfig) SetupAndValidateLogging(ctx context.Context) (err error ) } -func SetMaxFileDescriptors(maxP *uint64) error { +func SetMaxFileDescriptors(ctx context.Context, maxP *uint64) error { maxFDs := DefaultMaxFileDescriptors if maxP != nil { maxFDs = *maxP } - _, err := base.SetMaxFileDescriptors(maxFDs) + _, err := base.SetMaxFileDescriptors(ctx, maxFDs) if err != nil { - base.ErrorfCtx(context.Background(), "Error setting MaxFileDescriptors to %d: %v", maxFDs, err) + base.ErrorfCtx(ctx, "Error setting MaxFileDescriptors to %d: %v", maxFDs, err) return err } return nil } -func (sc *ServerContext) Serve(config *StartupConfig, addr string, handler http.Handler) error { +func (sc *ServerContext) Serve(ctx context.Context, config *StartupConfig, addr string, handler http.Handler) error { http2Enabled := false if config.Unsupported.HTTP2 != nil && config.Unsupported.HTTP2.Enabled != nil { http2Enabled = *config.Unsupported.HTTP2.Enabled @@ -1215,6 +1215,7 @@ func (sc *ServerContext) Serve(config *StartupConfig, addr string, handler http. tlsMinVersion := GetTLSVersionFromString(&config.API.HTTPS.TLSMinimumVersion) serveFn, server, err := base.ListenAndServeHTTP( + ctx, addr, config.API.MaximumConnections, config.API.HTTPS.TLSCertPath, @@ -1243,7 +1244,7 @@ func (sc *ServerContext) addHTTPServer(s *http.Server) { } // Validate returns errors errors if invalid config is present -func (sc *StartupConfig) Validate(isEnterpriseEdition bool) (errorMessages error) { +func (sc *StartupConfig) Validate(ctx context.Context, isEnterpriseEdition bool) (errorMessages error) { var multiError *base.MultiError if sc.Bootstrap.Server == "" { multiError = multiError.Append(fmt.Errorf("a server must be provided in the Bootstrap configuration")) @@ -1280,7 +1281,7 @@ func (sc *StartupConfig) Validate(isEnterpriseEdition bool) (errorMessages error if sc.DatabaseCredentials != nil { for dbName, creds := range sc.DatabaseCredentials { if (creds.X509CertPath != "" || creds.X509KeyPath != "") && (creds.Username != "" || creds.Password != "") { - base.WarnfCtx(context.TODO(), "database %q in database_credentials cannot use both x509 and basic auth. Will use x509 only.", base.MD(dbName)) + base.WarnfCtx(ctx, "database %q in database_credentials cannot use both x509 and basic auth. Will use x509 only.", base.MD(dbName)) } } } @@ -1329,11 +1330,11 @@ func SetupServerContext(ctx context.Context, config *StartupConfig, persistentCo base.InfofCtx(ctx, base.KeyAll, "Logging: Console keys: %v", base.ConsoleLogKey().EnabledLogKeys()) base.InfofCtx(ctx, base.KeyAll, "Logging: Redaction level: %s", config.Logging.RedactionLevel) - if err := setGlobalConfig(config); err != nil { + if err := setGlobalConfig(ctx, config); err != nil { return nil, err } - if err := config.Validate(base.IsEnterpriseEdition()); err != nil { + if err := config.Validate(ctx, base.IsEnterpriseEdition()); err != nil { return nil, err } @@ -1442,7 +1443,7 @@ func (sc *ServerContext) migrateV30Configs(ctx context.Context) error { for _, bucketName := range buckets { var dbConfig DatabaseConfig - legacyCas, getErr := sc.BootstrapContext.Connection.GetMetadataDocument(bucketName, PersistentConfigKey30(groupID), &dbConfig) + legacyCas, getErr := sc.BootstrapContext.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), &dbConfig) if getErr == base.ErrNotFound { continue } else if getErr != nil { @@ -1458,7 +1459,7 @@ func (sc *ServerContext) migrateV30Configs(ctx context.Context) error { } return fmt.Errorf("Error migrating v3.0 config for bucket %s groupID %s: %w", base.MD(bucketName), base.MD(groupID), insertErr) } - removeErr := sc.BootstrapContext.Connection.DeleteMetadataDocument(bucketName, PersistentConfigKey30(groupID), legacyCas) + removeErr := sc.BootstrapContext.Connection.DeleteMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), legacyCas) if removeErr != nil { base.InfofCtx(ctx, base.KeyConfig, "Failed to remove legacy config for database %s.", base.MD(dbConfig.Name)) } @@ -1494,7 +1495,7 @@ func (sc *ServerContext) fetchDatabase(ctx context.Context, dbName string) (foun // loop code moved to foreachDbConfig var cnf DatabaseConfig callback := func(bucket string) (exit bool, err error) { - cas, err := sc.BootstrapContext.GetConfig(bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, &cnf) + cas, err := sc.BootstrapContext.GetConfig(ctx, bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, &cnf) if err == base.ErrNotFound { base.DebugfCtx(ctx, base.KeyConfig, "%q did not contain config in group %q", bucket, sc.Config.Bootstrap.ConfigGroupID) return false, err @@ -1558,7 +1559,7 @@ func (sc *ServerContext) handleInvalidDatabaseConfig(ctx context.Context, bucket sc._removeDatabase(ctx, cnf.Name) } -func (sc *ServerContext) bucketNameFromDbName(dbName string) (bucketName string, found bool) { +func (sc *ServerContext) bucketNameFromDbName(ctx context.Context, dbName string) (bucketName string, found bool) { // Minimal representation of config struct to be tolerant of invalid database configurations where we still need to find a database name // see if we find the database in-memory first, otherwise fall back to scanning buckets for db configs sc.lock.RLock() @@ -1578,7 +1579,7 @@ func (sc *ServerContext) bucketNameFromDbName(dbName string) (bucketName string, cfgDbName := &dbConfigNameOnly{} callback := func(bucket string) (exit bool, err error) { - _, err = sc.BootstrapContext.GetConfigName(bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, cfgDbName) + _, err = sc.BootstrapContext.GetConfigName(ctx, bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, cfgDbName) if err != nil && err != base.ErrNotFound { return true, err } @@ -1833,23 +1834,23 @@ func StartServer(ctx context.Context, config *StartupConfig, sc *ServerContext) base.ConsolefCtx(ctx, base.LevelInfo, base.KeyAll, "Starting metrics server on %s", config.API.MetricsInterface) go func() { - if err := sc.Serve(config, config.API.MetricsInterface, CreateMetricHandler(sc)); err != nil { + if err := sc.Serve(ctx, config, config.API.MetricsInterface, CreateMetricHandler(sc)); err != nil { base.ErrorfCtx(ctx, "Error serving the Metrics API: %v", err) } }() base.ConsolefCtx(ctx, base.LevelInfo, base.KeyAll, "Starting admin server on %s", config.API.AdminInterface) go func() { - if err := sc.Serve(config, config.API.AdminInterface, CreateAdminHandler(sc)); err != nil { + if err := sc.Serve(ctx, config, config.API.AdminInterface, CreateAdminHandler(sc)); err != nil { base.ErrorfCtx(ctx, "Error serving the Admin API: %v", err) } }() base.ConsolefCtx(ctx, base.LevelInfo, base.KeyAll, "Starting server on %s ...", config.API.PublicInterface) - return sc.Serve(config, config.API.PublicInterface, CreatePublicHandler(sc)) + return sc.Serve(ctx, config, config.API.PublicInterface, CreatePublicHandler(sc)) } -func sharedBucketDatabaseCheck(sc *ServerContext) (errors error) { +func sharedBucketDatabaseCheck(ctx context.Context, sc *ServerContext) (errors error) { bucketUUIDToDBContext := make(map[string][]*db.DatabaseContext, len(sc.databases_)) for _, dbContext := range sc.databases_ { if uuid, err := dbContext.Bucket.UUID(); err == nil { @@ -1864,7 +1865,7 @@ func sharedBucketDatabaseCheck(sc *ServerContext) (errors error) { multiError = multiError.Append(sharedBucketError) messageFormat := "Bucket %q is shared among databases %s. " + "This may result in unexpected behaviour if security is not defined consistently." - base.WarnfCtx(context.Background(), messageFormat, base.MD(sharedBucket.bucketName), base.MD(sharedBucket.dbNames)) + base.WarnfCtx(ctx, messageFormat, base.MD(sharedBucket.bucketName), base.MD(sharedBucket.dbNames)) } return multiError.ErrorOrNil() } @@ -1930,9 +1931,9 @@ func (sc *ServerContext) _findDuplicateCollections(cnf DatabaseConfig) []string } // PersistentConfigKey returns a document key to use to store database configs -func PersistentConfigKey(groupID string, metadataID string) string { +func PersistentConfigKey(ctx context.Context, groupID string, metadataID string) string { if groupID == "" { - base.WarnfCtx(context.TODO(), "Empty group ID specified for PersistentConfigKey - using %v", PersistentConfigDefaultGroupID) + base.WarnfCtx(ctx, "Empty group ID specified for PersistentConfigKey - using %v", PersistentConfigDefaultGroupID) groupID = PersistentConfigDefaultGroupID } if metadataID == "" { @@ -1943,18 +1944,18 @@ func PersistentConfigKey(groupID string, metadataID string) string { } // Return the persistent config key for a legacy 3.0 persistent config (single database per bucket model) -func PersistentConfigKey30(groupID string) string { +func PersistentConfigKey30(ctx context.Context, groupID string) string { if groupID == "" { - base.WarnfCtx(context.TODO(), "Empty group ID specified for PersistentConfigKey - using %v", PersistentConfigDefaultGroupID) + base.WarnfCtx(ctx, "Empty group ID specified for PersistentConfigKey - using %v", PersistentConfigDefaultGroupID) groupID = PersistentConfigDefaultGroupID } return base.PersistentConfigPrefixWithoutGroupID + groupID } -func HandleSighup() { - for logger, err := range base.RotateLogfiles() { +func HandleSighup(ctx context.Context) { + for logger, err := range base.RotateLogfiles(ctx) { if err != nil { - base.WarnfCtx(context.Background(), "Error rotating %v: %v", logger, err) + base.WarnfCtx(ctx, "Error rotating %v: %v", logger, err) } } } @@ -1972,7 +1973,7 @@ func RegisterSignalHandler(ctx context.Context) { base.InfofCtx(ctx, base.KeyAll, "Handling signal: %v", sig) switch sig { case syscall.SIGHUP: - HandleSighup() + HandleSighup(ctx) default: // Ensure log buffers are flushed before exiting. base.FlushLogBuffers() @@ -1981,3 +1982,16 @@ func RegisterSignalHandler(ctx context.Context) { } }() } + +// toDbConsoleLogConfig converts the console logging from a DbConfig to a DbConsoleLogConfig +func (c *DbConfig) toDbConsoleLogConfig(ctx context.Context) *base.DbConsoleLogConfig { + // Per-database console logging config overrides + if c.Logging != nil && c.Logging.Console != nil { + logKey := base.ToLogKey(ctx, c.Logging.Console.LogKeys) + return &base.DbConsoleLogConfig{ + LogLevel: c.Logging.Console.LogLevel, + LogKeys: &logKey, + } + } + return nil +} diff --git a/rest/config_database.go b/rest/config_database.go index 2af0dfed5d..0d10940fc6 100644 --- a/rest/config_database.go +++ b/rest/config_database.go @@ -9,6 +9,8 @@ package rest import ( + "context" + "github.com/couchbase/sync_gateway/base" "github.com/couchbase/sync_gateway/channels" "github.com/couchbase/sync_gateway/db" @@ -40,7 +42,7 @@ type DatabaseConfig struct { DbConfig } -func (dbc *DatabaseConfig) Redacted() (*DatabaseConfig, error) { +func (dbc *DatabaseConfig) Redacted(ctx context.Context) (*DatabaseConfig, error) { var config DatabaseConfig err := base.DeepCopyInefficient(&config, dbc) @@ -48,7 +50,7 @@ func (dbc *DatabaseConfig) Redacted() (*DatabaseConfig, error) { return nil, err } - err = config.DbConfig.redactInPlace() + err = config.DbConfig.redactInPlace(ctx) if err != nil { return nil, err } @@ -70,13 +72,13 @@ func (dbc *DatabaseConfig) GetCollectionNames() base.ScopeAndCollectionNames { return collections } -func GenerateDatabaseConfigVersionID(previousRevID string, dbConfig *DbConfig) (string, error) { +func GenerateDatabaseConfigVersionID(ctx context.Context, previousRevID string, dbConfig *DbConfig) (string, error) { encodedBody, err := base.JSONMarshalCanonical(dbConfig) if err != nil { return "", err } - previousGen, previousRev := db.ParseRevID(previousRevID) + previousGen, previousRev := db.ParseRevID(ctx, previousRevID) generation := previousGen + 1 hash := db.CreateRevIDWithBytes(generation, previousRev, encodedBody) diff --git a/rest/config_legacy.go b/rest/config_legacy.go index 15538132de..901dd0765b 100644 --- a/rest/config_legacy.go +++ b/rest/config_legacy.go @@ -110,7 +110,7 @@ type UnsupportedServerConfigLegacy struct { // ToStartupConfig returns the given LegacyServerConfig as a StartupConfig and a set of DBConfigs. // The returned configs do not contain any default values - only a direct mapping of legacy config options as they were given. -func (lc *LegacyServerConfig) ToStartupConfig() (*StartupConfig, DbConfigMap, error) { +func (lc *LegacyServerConfig) ToStartupConfig(ctx context.Context) (*StartupConfig, DbConfigMap, error) { // find a database's credentials for bootstrap (this isn't the first database config entry due to map iteration) bootstrapConfigIsSet := false bsc := &BootstrapConfig{} @@ -139,7 +139,7 @@ func (lc *LegacyServerConfig) ToStartupConfig() (*StartupConfig, DbConfigMap, er server, username, password, err := legacyServerAddressUpgrade(*dbConfig.Server) if err != nil { server = *dbConfig.Server - base.ErrorfCtx(context.Background(), "Error upgrading server address: %v", err) + base.ErrorfCtx(ctx, "Error upgrading server address: %v", err) } dbConfig.Server = base.StringPtr(server) @@ -357,19 +357,19 @@ func (clusterConfig *ClusterConfigLegacy) GetCredentials() (string, string, stri } // LoadLegacyServerConfig loads a LegacyServerConfig from either a JSON file or from a URL -func LoadLegacyServerConfig(path string) (config *LegacyServerConfig, err error) { - rc, err := readFromPath(path, false) +func LoadLegacyServerConfig(ctx context.Context, path string) (config *LegacyServerConfig, err error) { + rc, err := readFromPath(ctx, path, false) if err != nil { return nil, err } defer func() { _ = rc.Close() }() - return readLegacyServerConfig(rc) + return readLegacyServerConfig(ctx, rc) } // readLegacyServerConfig returns a validated LegacyServerConfig from an io.Reader -func readLegacyServerConfig(r io.Reader) (config *LegacyServerConfig, err error) { - err = DecodeAndSanitiseConfig(r, &config, true) +func readLegacyServerConfig(ctx context.Context, r io.Reader) (config *LegacyServerConfig, err error) { + err = DecodeAndSanitiseConfig(ctx, r, &config, true) if err != nil { return config, err } @@ -382,10 +382,10 @@ func readLegacyServerConfig(r io.Reader) (config *LegacyServerConfig, err error) // setupServerConfig parses command-line flags, reads the optional configuration file, // performs the config validation and database setup. -func setupServerConfig(args []string) (config *LegacyServerConfig, err error) { +func setupServerConfig(ctx context.Context, args []string) (config *LegacyServerConfig, err error) { var unknownFieldsErr error - config, err = ParseCommandLine(args, flag.ExitOnError) + config, err = ParseCommandLine(ctx, args, flag.ExitOnError) if pkgerrors.Cause(err) == base.ErrUnknownField { unknownFieldsErr = err } else if err != nil { @@ -401,32 +401,32 @@ func setupServerConfig(args []string) (config *LegacyServerConfig, err error) { // Validation var multiError *base.MultiError multiError = multiError.Append(config.validate()) - multiError = multiError.Append(config.setupAndValidateDatabases()) + multiError = multiError.Append(config.setupAndValidateDatabases(ctx)) if multiError.ErrorOrNil() != nil { - base.ErrorfCtx(context.Background(), "Error during config validation: %v", multiError) + base.ErrorfCtx(ctx, "Error during config validation: %v", multiError) return nil, fmt.Errorf("error(s) during config validation: %v", multiError) } return config, nil } -func setupAndValidateDatabases(databases DbConfigMap) error { +func setupAndValidateDatabases(ctx context.Context, databases DbConfigMap) error { for name, dbConfig := range databases { - if err := dbConfig.setup(name, BootstrapConfig{}, nil, nil, false); err != nil { + if err := dbConfig.setup(ctx, name, BootstrapConfig{}, nil, nil, false); err != nil { return err } - if err := dbConfig.validate(context.Background(), false); err != nil { + if err := dbConfig.validate(ctx, false); err != nil { return err } } return nil } -func (config *LegacyServerConfig) setupAndValidateDatabases() error { +func (config *LegacyServerConfig) setupAndValidateDatabases(ctx context.Context) error { if config == nil { return nil } - return setupAndValidateDatabases(config.Databases) + return setupAndValidateDatabases(ctx, config.Databases) } // validate validates the given server config and returns all invalid options as a slice of errors @@ -487,7 +487,7 @@ func (self *LegacyServerConfig) MergeWith(other *LegacyServerConfig) error { return nil } -func (sc *LegacyServerConfig) Redacted() (*LegacyServerConfig, error) { +func (sc *LegacyServerConfig) Redacted(ctx context.Context) (*LegacyServerConfig, error) { var config LegacyServerConfig err := base.DeepCopyInefficient(&config, sc) @@ -496,7 +496,7 @@ func (sc *LegacyServerConfig) Redacted() (*LegacyServerConfig, error) { } for i := range config.Databases { - config.Databases[i], err = config.Databases[i].Redacted() + config.Databases[i], err = config.Databases[i].Redacted(ctx) if err != nil { return nil, err } @@ -506,7 +506,7 @@ func (sc *LegacyServerConfig) Redacted() (*LegacyServerConfig, error) { } // Reads the command line flags and the optional config file. -func ParseCommandLine(args []string, handling flag.ErrorHandling) (*LegacyServerConfig, error) { +func ParseCommandLine(ctx context.Context, args []string, handling flag.ErrorHandling) (*LegacyServerConfig, error) { flagSet := flag.NewFlagSet(args[0], handling) _ = flagSet.Bool("disable_persistent_config", false, "") @@ -541,7 +541,7 @@ func ParseCommandLine(args []string, handling flag.ErrorHandling) (*LegacyServer if flagSet.NArg() > 0 { // Read the configuration file(s), if any: for _, filename := range flagSet.Args() { - newConfig, newConfigErr := LoadLegacyServerConfig(filename) + newConfig, newConfigErr := LoadLegacyServerConfig(ctx, filename) if newConfigErr != nil { return config, pkgerrors.WithMessage(newConfigErr, fmt.Sprintf("Error reading config file %s", filename)) diff --git a/rest/config_legacy_test.go b/rest/config_legacy_test.go index 4e129e9ea4..d6a7b47d75 100644 --- a/rest/config_legacy_test.go +++ b/rest/config_legacy_test.go @@ -120,7 +120,7 @@ func TestLegacyConfigToStartupConfig(t *testing.T) { t.Run(test.name, func(t *testing.T) { lc := &test.input - migratedStartupConfig, _, err := lc.ToStartupConfig() + migratedStartupConfig, _, err := lc.ToStartupConfig(base.TestCtx(t)) require.NoError(t, err) config := test.base @@ -253,7 +253,7 @@ func TestLegacyConfigXattrsDefault(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { lc := LegacyServerConfig{Databases: DbConfigMap{"db": &DbConfig{EnableXattrs: test.xattrs}}} - _, dbs, err := lc.ToStartupConfig() + _, dbs, err := lc.ToStartupConfig(base.TestCtx(t)) require.NoError(t, err) db, ok := dbs["db"] @@ -277,7 +277,7 @@ func TestSGReplicateValidation(t *testing.T) { ] }`) - _, err := readLegacyServerConfig(configReader) + _, err := readLegacyServerConfig(base.TestCtx(t), configReader) require.Error(t, err) assert.Contains(t, err.Error(), errText) } @@ -328,17 +328,18 @@ func TestLegacyGuestUserMigration(t *testing.T) { err := os.WriteFile(configPath, []byte(config), os.FileMode(0644)) require.NoError(t, err) - sc, _, _, _, err := automaticConfigUpgrade(configPath) + ctx := base.TestCtx(t) + sc, _, _, _, err := automaticConfigUpgrade(ctx, configPath) require.NoError(t, err) - cluster, err := CreateCouchbaseClusterFromStartupConfig(sc, base.PerUseClusterConnections) + cluster, err := CreateCouchbaseClusterFromStartupConfig(ctx, sc, base.PerUseClusterConnections) require.NoError(t, err) bootstrap := bootstrapContext{ Connection: cluster, } var dbConfig DatabaseConfig - _, err = bootstrap.GetConfig(tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) + _, err = bootstrap.GetConfig(base.TestCtx(t), tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) require.NoError(t, err) assert.Equal(t, &expected, dbConfig.Guest) @@ -432,12 +433,12 @@ func TestLegacyConfigPrinciplesMigration(t *testing.T) { require.NoError(t, err) // Copy behaviour of serverMainPersistentConfig - upgrade config, pass legacy users and roles in to addLegacyPrinciples (after server context is created) - _, _, users, roles, err := automaticConfigUpgrade(configPath) + _, _, users, roles, err := automaticConfigUpgrade(ctx, configPath) require.NoError(t, err) rt.ServerContext().addLegacyPrincipals(ctx, users, roles) // Check that principles all exist on bucket - authenticator := auth.NewAuthenticator(bucket.DefaultDataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(bucket.DefaultDataStore(), nil, rt.GetDatabase().AuthenticatorOptions(ctx)) for _, name := range expectedUsers { user, err := authenticator.GetUser(name) assert.NoError(t, err) @@ -486,7 +487,7 @@ func TestLegacyReplicationConfigValidation(t *testing.T) { lc := LegacyServerConfig{Databases: DbConfigMap{"db": &DbConfig{Replications: test.input}}} - _, _, err := lc.ToStartupConfig() + _, _, err := lc.ToStartupConfig(base.TestCtx(t)) fmt.Println(err) if test.expectError { assert.Error(t, err) diff --git a/rest/config_manager.go b/rest/config_manager.go index 47bc311694..dcbdf2499a 100644 --- a/rest/config_manager.go +++ b/rest/config_manager.go @@ -19,7 +19,7 @@ import ( // ConfigManager should be used for any read/write of persisted database configuration files type ConfigManager interface { // GetConfig fetches a database config for a given bucket and config group ID, along with the CAS of the config document. Does not enforce version match with registry. - GetConfig(bucket, groupID, dbName string, config *DatabaseConfig) (cas uint64, err error) + GetConfig(ctx context.Context, bucket, groupID, dbName string, config *DatabaseConfig) (cas uint64, err error) // GetDatabaseConfigs returns all configs for the bucket and config group. Enforces version match with registry. GetDatabaseConfigs(ctx context.Context, bucketName, groupID string) ([]*DatabaseConfig, error) // InsertConfig saves a new database config for a given bucket and config group ID. @@ -42,15 +42,15 @@ const configFetchMaxRetryAttempts = 5 // Maximum number of retries due to regis const defaultMetadataID = "_default" // GetConfig fetches a database name for a given bucket and config group ID. -func (b *bootstrapContext) GetConfigName(bucketName, groupID, dbName string, configName *dbConfigNameOnly) (cas uint64, err error) { - return b.Connection.GetMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), configName) +func (b *bootstrapContext) GetConfigName(ctx context.Context, bucketName, groupID, dbName string, configName *dbConfigNameOnly) (cas uint64, err error) { + return b.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), configName) } // GetConfig fetches a database config for a given bucket and config group ID, along with the CAS of the config document. // GetConfig does *not* validate that config version matches registry version - operations requiring synchronization // with registry should use getRegistryAndDatabase, or getConfig with the required version -func (b *bootstrapContext) GetConfig(bucketName, groupID, dbName string, config *DatabaseConfig) (cas uint64, err error) { - return b.Connection.GetMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), config) +func (b *bootstrapContext) GetConfig(ctx context.Context, bucketName, groupID, dbName string, config *DatabaseConfig) (cas uint64, err error) { + return b.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), config) } // InsertConfig saves a new database config for a given bucket and config group ID. This is a three-step process: @@ -60,7 +60,7 @@ func (b *bootstrapContext) GetConfig(bucketName, groupID, dbName string, config func (b *bootstrapContext) InsertConfig(ctx context.Context, bucketName, groupID string, config *DatabaseConfig) (newCAS uint64, err error) { dbName := config.Name attempts := 0 - ctx = b.addDatabaseLogContext(ctx, dbName) + ctx = b.addDatabaseLogContext(ctx, &config.DbConfig) for attempts < configUpdateMaxRetryAttempts { attempts++ base.InfofCtx(ctx, base.KeyConfig, "InsertConfig into bucket %s starting (attempt %d/%d)", bucketName, attempts, configUpdateMaxRetryAttempts) @@ -98,7 +98,7 @@ func (b *bootstrapContext) InsertConfig(ctx context.Context, bucketName, groupID } // Persist registry - writeErr := b.setGatewayRegistry(bucketName, registry) + writeErr := b.setGatewayRegistry(ctx, bucketName, registry) if writeErr == nil { break } @@ -117,7 +117,7 @@ func (b *bootstrapContext) InsertConfig(ctx context.Context, bucketName, groupID } } // Step 3. Write the database config - cas, configErr := b.Connection.InsertMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), config) + cas, configErr := b.Connection.InsertMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), config) if configErr != nil { base.InfofCtx(ctx, base.KeyConfig, "Insert for database config returned error %v", configErr) } else { @@ -137,7 +137,6 @@ func (b *bootstrapContext) UpdateConfig(ctx context.Context, bucketName, groupID var previousVersion string attempts := 0 - ctx = b.addDatabaseLogContext(ctx, dbName) outer: for attempts < configUpdateMaxRetryAttempts { attempts++ @@ -145,6 +144,9 @@ outer: // Step 1. Fetch registry and databases - enforces registry/config synchronization var existingConfig *DatabaseConfig registry, existingConfig, err = b.getRegistryAndDatabase(ctx, bucketName, groupID, dbName) + if existingConfig != nil { + ctx = b.addDatabaseLogContext(ctx, &existingConfig.DbConfig) + } if err != nil { base.InfofCtx(ctx, base.KeyConfig, "UpdateConfig unable to retrieve registry and database: %v", err) return 0, err @@ -180,7 +182,7 @@ outer: } // Persist registry - writeErr := b.setGatewayRegistry(bucketName, registry) + writeErr := b.setGatewayRegistry(ctx, bucketName, registry) if writeErr == nil { break } @@ -200,7 +202,7 @@ outer: } // Step 2. Update the config document - casOut, err := b.Connection.WriteMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), updatedConfig.cfgCas, updatedConfig) + casOut, err := b.Connection.WriteMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), updatedConfig.cfgCas, updatedConfig) if err != nil { base.InfofCtx(ctx, base.KeyConfig, "Write for database config returned error %v", err) return 0, err @@ -212,7 +214,7 @@ outer: if err != nil { return 0, fmt.Errorf("Error removing previous version of config group: %s, database: %s from registry after successful update: %w", base.MD(groupID), base.MD(dbName), err) } - writeErr := b.setGatewayRegistry(bucketName, registry) + writeErr := b.setGatewayRegistry(ctx, bucketName, registry) if writeErr != nil { return 0, fmt.Errorf("Error persisting removal of previous version of config group: %s, database: %s from registry after successful update: %w", base.MD(groupID), base.MD(dbName), writeErr) } @@ -229,7 +231,6 @@ func (b *bootstrapContext) DeleteConfig(ctx context.Context, bucketName, groupID var existingCas uint64 var registry *GatewayRegistry attempts := 0 - ctx = b.addDatabaseLogContext(ctx, dbName) outer: for attempts < configUpdateMaxRetryAttempts { attempts++ @@ -254,7 +255,7 @@ outer: } // Persist registry - writeErr := b.setGatewayRegistry(bucketName, registry) + writeErr := b.setGatewayRegistry(ctx, bucketName, registry) if writeErr == nil { break } @@ -272,7 +273,7 @@ outer: } } - err = b.Connection.DeleteMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), existingCas) + err = b.Connection.DeleteMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), existingCas) if err != nil { base.InfofCtx(ctx, base.KeyConfig, "Delete for database config returned error %v", err) return err @@ -284,7 +285,7 @@ outer: if !found { base.InfofCtx(ctx, base.KeyConfig, "Database not found in registry during finalization") } else { - writeErr := b.setGatewayRegistry(bucketName, registry) + writeErr := b.setGatewayRegistry(ctx, bucketName, registry) if writeErr != nil { return fmt.Errorf("Error persisting removal of previous version of config group: %s, database: %s from registry after successful delete: %w", base.MD(groupID), base.MD(dbName), writeErr) } @@ -327,7 +328,7 @@ func (b *bootstrapContext) GetDatabaseConfigs(ctx context.Context, bucketName, g // Check for legacy config file var legacyConfig DatabaseConfig var legacyDbName string - cas, legacyErr := b.Connection.GetMetadataDocument(bucketName, PersistentConfigKey(groupID, ""), &legacyConfig) + cas, legacyErr := b.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, ""), &legacyConfig) if legacyErr != nil && legacyErr != base.ErrNotFound { return nil, fmt.Errorf("Error checking for legacy config for %s, %s: %w", base.MD(bucketName), base.MD(groupID), legacyErr) } @@ -388,8 +389,8 @@ func (b *bootstrapContext) getConfigVersionWithRetry(ctx context.Context, bucket retryWorker := func() (shouldRetry bool, err error, value interface{}) { config := &DatabaseConfig{} - metadataKey := PersistentConfigKey(groupID, dbName) - cas, err := b.Connection.GetMetadataDocument(bucketName, metadataKey, config) + metadataKey := PersistentConfigKey(ctx, groupID, dbName) + cas, err := b.Connection.GetMetadataDocument(ctx, bucketName, metadataKey, config) if err == base.ErrNotFound { return true, base.ErrConfigRegistryRollback, nil } @@ -404,8 +405,8 @@ func (b *bootstrapContext) getConfigVersionWithRetry(ctx context.Context, bucket } // For version mismatch, handling depends on whether config has newer or older version than requested - requestedGen, _ := db.ParseRevID(version) - currentGen, _ := db.ParseRevID(config.Version) + requestedGen, _ := db.ParseRevID(ctx, version) + currentGen, _ := db.ParseRevID(ctx, config.Version) if currentGen > requestedGen { // If the config has a newer version than requested, return the config but alert caller that they have // requested a stale version. @@ -418,6 +419,7 @@ func (b *bootstrapContext) getConfigVersionWithRetry(ctx context.Context, bucket // Kick off the retry loop err, retryResult := base.RetryLoop( + ctx, "Wait for config version match", retryWorker, base.CreateDoublingSleeperDurationFunc(50, timeout), @@ -444,7 +446,7 @@ func (b *bootstrapContext) getConfigVersionWithRetry(ctx context.Context, bucket // triggers registry rollback and returns rollback error func (b *bootstrapContext) getDatabaseConfig(ctx context.Context, bucketName, groupID, dbName string, version string, registry *GatewayRegistry) (*DatabaseConfig, error) { - ctx = b.addDatabaseLogContext(ctx, dbName) + ctx = b.addDatabaseLogContext(ctx, &DbConfig{Name: dbName}) config, err := b.getConfigVersionWithRetry(ctx, bucketName, groupID, dbName, version) if err != nil { if err == base.ErrConfigRegistryRollback { @@ -479,7 +481,7 @@ func (b *bootstrapContext) waitForConfigDelete(ctx context.Context, bucketName, retryWorker := func() (shouldRetry bool, err error, value interface{}) { config := &DatabaseConfig{} - cas, getErr := b.Connection.GetMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), config) + cas, getErr := b.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), config) // Success case - delete has been completed if getErr == base.ErrNotFound { return false, nil, nil @@ -498,6 +500,7 @@ func (b *bootstrapContext) waitForConfigDelete(ctx context.Context, bucketName, // Kick off the retry loop err, retryResult := base.RetryLoop( + ctx, "Wait for config version match", retryWorker, base.CreateDoublingSleeperDurationFunc(50, timeout), @@ -510,7 +513,7 @@ func (b *bootstrapContext) waitForConfigDelete(ctx context.Context, bucketName, return fmt.Errorf("Unable to convert returned cas of type %T to uint64", retryResult) } - err = b.Connection.DeleteMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), existingCas) + err = b.Connection.DeleteMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), existingCas) if err != nil { return err } @@ -541,7 +544,7 @@ func (b *bootstrapContext) rollbackRegistry(ctx context.Context, bucketName, gro } else { // Mark the database config being rolled back first to update CAS, to ensure a slow writer doesn't succeed while we're rolling back. // Use the database name property for the update, as this is otherwise immutable. - casOut, err := b.Connection.TouchMetadataDocument(bucketName, PersistentConfigKey(groupID, dbName), "name", dbName, config.cfgCas) + casOut, err := b.Connection.TouchMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, dbName), "name", dbName, config.cfgCas) if err != nil { return fmt.Errorf("Rollback cancelled - document has been updated") } @@ -559,7 +562,7 @@ func (b *bootstrapContext) rollbackRegistry(ctx context.Context, bucketName, gro } // Attempt to persist the registry - casOut, err := b.Connection.WriteMetadataDocument(bucketName, base.SGRegistryKey, registry.cas, registry) + casOut, err := b.Connection.WriteMetadataDocument(ctx, bucketName, base.SGRegistryKey, registry.cas, registry) if err == nil { registry.cas = casOut base.InfofCtx(ctx, base.KeyConfig, "Successful config registry rollback for bucket: %s, configGroup: %s, db: %s", base.MD(bucketName), base.MD(groupID), base.MD(dbName)) @@ -568,10 +571,10 @@ func (b *bootstrapContext) rollbackRegistry(ctx context.Context, bucketName, gro } // getGatewayRegistry returns the database registry document for the bucket -func (b *bootstrapContext) getGatewayRegistry(_ context.Context, bucketName string) (result *GatewayRegistry, err error) { +func (b *bootstrapContext) getGatewayRegistry(ctx context.Context, bucketName string) (result *GatewayRegistry, err error) { registry := &GatewayRegistry{} - cas, getErr := b.Connection.GetMetadataDocument(bucketName, base.SGRegistryKey, registry) + cas, getErr := b.Connection.GetMetadataDocument(ctx, bucketName, base.SGRegistryKey, registry) if getErr != nil { if getErr == base.ErrNotFound { return NewGatewayRegistry(), nil @@ -584,7 +587,7 @@ func (b *bootstrapContext) getGatewayRegistry(_ context.Context, bucketName stri } // getGatewayRegistry returns the database registry document for the bucket -func (b *bootstrapContext) setGatewayRegistry(bucketName string, registry *GatewayRegistry) (err error) { +func (b *bootstrapContext) setGatewayRegistry(ctx context.Context, bucketName string, registry *GatewayRegistry) (err error) { cas := uint64(0) if registry != nil { @@ -594,9 +597,9 @@ func (b *bootstrapContext) setGatewayRegistry(bucketName string, registry *Gatew var casOut uint64 var writeErr error if cas == 0 { - casOut, writeErr = b.Connection.InsertMetadataDocument(bucketName, base.SGRegistryKey, registry) + casOut, writeErr = b.Connection.InsertMetadataDocument(ctx, bucketName, base.SGRegistryKey, registry) } else { - casOut, writeErr = b.Connection.WriteMetadataDocument(bucketName, base.SGRegistryKey, cas, registry) + casOut, writeErr = b.Connection.WriteMetadataDocument(ctx, bucketName, base.SGRegistryKey, cas, registry) } if writeErr != nil { @@ -685,8 +688,8 @@ func (b *bootstrapContext) getRegistryAndDatabase(ctx context.Context, bucketNam } -func (b *bootstrapContext) addDatabaseLogContext(ctx context.Context, dbName string) context.Context { - return base.DatabaseLogCtx(ctx, dbName, nil) +func (b *bootstrapContext) addDatabaseLogContext(ctx context.Context, config *DbConfig) context.Context { + return base.DatabaseLogCtx(ctx, config.Name, config.toDbConsoleLogConfig(ctx)) } func (b *bootstrapContext) ComputeMetadataIDForDbConfig(ctx context.Context, config *DbConfig) (string, error) { @@ -739,7 +742,7 @@ func (b *bootstrapContext) computeMetadataID(ctx context.Context, registry *Gate // If _default._default is already associated with a metadataID, return standard metadata ID bucketName := *config.Bucket - exists, err := b.Connection.KeyExists(bucketName, base.SGSyncInfo) + exists, err := b.Connection.KeyExists(ctx, bucketName, base.SGSyncInfo) if err != nil { base.WarnfCtx(ctx, "Error checking whether metadataID is already defined for default collection - using standard metadataID. Error: %v", err) return standardMetadataID @@ -749,7 +752,7 @@ func (b *bootstrapContext) computeMetadataID(ctx context.Context, registry *Gate } // If legacy _sync:seq doesn't exist, use the standard ID - legacySyncSeqExists, _ := b.Connection.KeyExists(bucketName, base.DefaultMetadataKeys.SyncSeqKey()) + legacySyncSeqExists, _ := b.Connection.KeyExists(ctx, bucketName, base.DefaultMetadataKeys.SyncSeqKey()) if !legacySyncSeqExists { return standardMetadataID } diff --git a/rest/config_manager_test.go b/rest/config_manager_test.go index c46cc34e17..3b04b6caba 100644 --- a/rest/config_manager_test.go +++ b/rest/config_manager_test.go @@ -49,7 +49,7 @@ func TestBootstrapConfig(t *testing.T) { var dbConfig1 *DatabaseConfig - _, err = bootstrapContext.GetConfig(bucketName, configGroup1, db1Name, dbConfig1) + _, err = bootstrapContext.GetConfig(ctx, bucketName, configGroup1, db1Name, dbConfig1) require.Error(t, err) } diff --git a/rest/config_startup.go b/rest/config_startup.go index 5da211f21d..54326248e1 100644 --- a/rest/config_startup.go +++ b/rest/config_startup.go @@ -197,8 +197,8 @@ func (sc *StartupConfig) IsServerless() bool { return base.BoolDefault(sc.Unsupported.Serverless.Enabled, false) } -func LoadStartupConfigFromPath(path string) (*StartupConfig, error) { - rc, err := readFromPath(path, false) +func LoadStartupConfigFromPath(ctx context.Context, path string) (*StartupConfig, error) { + rc, err := readFromPath(ctx, path, false) if err != nil { return nil, err } @@ -206,7 +206,7 @@ func LoadStartupConfigFromPath(path string) (*StartupConfig, error) { defer func() { _ = rc.Close() }() var sc StartupConfig - err = DecodeAndSanitiseConfig(rc, &sc, true) + err = DecodeAndSanitiseConfig(ctx, rc, &sc, true) return &sc, err } @@ -233,7 +233,7 @@ func NewEmptyStartupConfig() StartupConfig { // setGlobalConfig will set global variables and other settings based on the given StartupConfig. // We should try to keep these minimal where possible, and favour ServerContext-scoped values. -func setGlobalConfig(sc *StartupConfig) error { +func setGlobalConfig(ctx context.Context, sc *StartupConfig) error { // Per-process limits, can't be scoped any narrower. if os.Getenv("GOMAXPROCS") == "" && runtime.GOMAXPROCS(0) == 1 { @@ -241,12 +241,12 @@ func setGlobalConfig(sc *StartupConfig) error { cpus := runtime.NumCPU() if cpus > 1 { runtime.GOMAXPROCS(cpus) - base.InfofCtx(context.Background(), base.KeyAll, "Configured Go to use all %d CPUs; setenv GOMAXPROCS to override this", cpus) + base.InfofCtx(ctx, base.KeyAll, "Configured Go to use all %d CPUs; setenv GOMAXPROCS to override this", cpus) } } - if _, err := base.SetMaxFileDescriptors(sc.MaxFileDescriptors); err != nil { - base.ErrorfCtx(context.Background(), "Error setting MaxFileDescriptors to %d: %v", sc.MaxFileDescriptors, err) + if _, err := base.SetMaxFileDescriptors(ctx, sc.MaxFileDescriptors); err != nil { + base.ErrorfCtx(ctx, "Error setting MaxFileDescriptors to %d: %v", sc.MaxFileDescriptors, err) } // TODO: Remove with GoCB DCP switch @@ -256,7 +256,7 @@ func setGlobalConfig(sc *StartupConfig) error { // Given unscoped usage of base.JSON functions, this can't be scoped. if base.BoolDefault(sc.Unsupported.UseStdlibJSON, false) { - base.InfofCtx(context.Background(), base.KeyAll, "Using the stdlib JSON package") + base.InfofCtx(ctx, base.KeyAll, "Using the stdlib JSON package") base.UseStdlibJSON = true } diff --git a/rest/config_test.go b/rest/config_test.go index bc022d5996..43f46820d1 100644 --- a/rest/config_test.go +++ b/rest/config_test.go @@ -12,7 +12,6 @@ package rest import ( "bytes" - "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -90,7 +89,7 @@ func TestReadServerConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { buf := bytes.NewBufferString(test.config) - _, err := readLegacyServerConfig(buf) + _, err := readLegacyServerConfig(base.TestCtx(t), buf) // stdlib/CE specific error checking expectedErr := test.errStdlib @@ -135,10 +134,11 @@ func TestConfigValidation(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := base.TestCtx(t) buf := bytes.NewBufferString(test.config) - config, err := readLegacyServerConfig(buf) + config, err := readLegacyServerConfig(ctx, buf) assert.NoError(t, err) - errorMessages := config.setupAndValidateDatabases() + errorMessages := config.setupAndValidateDatabases(ctx) if test.err != "" { require.NotNil(t, errorMessages) multiError, ok := errorMessages.(*base.MultiError) @@ -154,12 +154,12 @@ func TestConfigValidation(t *testing.T) { func TestConfigValidationDeltaSync(t *testing.T) { jsonConfig := `{"databases": {"db": {"delta_sync": {"enabled": true}}}}` - + ctx := base.TestCtx(t) buf := bytes.NewBufferString(jsonConfig) - config, err := readLegacyServerConfig(buf) + config, err := readLegacyServerConfig(ctx, buf) assert.NoError(t, err) - errorMessages := config.setupAndValidateDatabases() + errorMessages := config.setupAndValidateDatabases(ctx) require.NoError(t, errorMessages) require.NotNil(t, config.Databases["db"]) @@ -175,12 +175,12 @@ func TestConfigValidationDeltaSync(t *testing.T) { func TestConfigValidationCache(t *testing.T) { jsonConfig := `{"databases": {"db": {"cache": {"rev_cache": {"size": 0}, "channel_cache": {"max_number": 100, "compact_high_watermark_pct": 95, "compact_low_watermark_pct": 25}}}}}` - + ctx := base.TestCtx(t) buf := bytes.NewBufferString(jsonConfig) - config, err := readLegacyServerConfig(buf) + config, err := readLegacyServerConfig(ctx, buf) assert.NoError(t, err) - errorMessages := config.setupAndValidateDatabases() + errorMessages := config.setupAndValidateDatabases(ctx) require.NoError(t, errorMessages) require.NotNil(t, config.Databases["db"]) @@ -223,12 +223,12 @@ func TestConfigValidationCache(t *testing.T) { func TestConfigValidationImport(t *testing.T) { jsonConfig := `{"databases": {"db": {"enable_shared_bucket_access":true, "import_docs": true, "import_partitions": 32}}}` - + ctx := base.TestCtx(t) buf := bytes.NewBufferString(jsonConfig) - config, err := readLegacyServerConfig(buf) + config, err := readLegacyServerConfig(ctx, buf) assert.NoError(t, err) - errorMessages := config.setupAndValidateDatabases() + errorMessages := config.setupAndValidateDatabases(ctx) require.NoError(t, errorMessages) require.NotNil(t, config.Databases["db"]) @@ -289,10 +289,11 @@ func TestConfigValidationImportPartitions(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := base.TestCtx(t) buf := bytes.NewBufferString(test.config) - config, err := readLegacyServerConfig(buf) + config, err := readLegacyServerConfig(ctx, buf) assert.NoError(t, err) - errorMessages := config.setupAndValidateDatabases() + errorMessages := config.setupAndValidateDatabases(ctx) if test.err != "" { require.NotNil(t, errorMessages) multiError, ok := errorMessages.(*base.MultiError) @@ -452,7 +453,7 @@ func TestLoadServerConfigExamples(t *testing.T) { } t.Run(configPath, func(tt *testing.T) { - _, err := LoadLegacyServerConfig(configPath) + _, err := LoadLegacyServerConfig(base.TestCtx(t), configPath) assert.NoError(tt, err, "unexpected error validating example config") }) @@ -597,7 +598,7 @@ func TestAutoImportEnabled(t *testing.T) { t.Run(test.name, func(t *testing.T) { dbConfig := &DbConfig{AutoImport: test.configValue} - got, err := dbConfig.AutoImportEnabled() + got, err := dbConfig.AutoImportEnabled(base.TestCtx(t)) assert.Equal(t, test.hasError, err != nil, "unexpected error from AutoImportEnabled") assert.Equal(t, test.expected, got, "unexpected value from AutoImportEnabled") }) @@ -720,9 +721,10 @@ func TestServerConfigValidate(t *testing.T) { } func TestSetupAndValidateDatabases(t *testing.T) { + ctx := base.TestCtx(t) // No error will be returned if the server config itself is nil var sc *LegacyServerConfig - errs := sc.setupAndValidateDatabases() + errs := sc.setupAndValidateDatabases(ctx) assert.Nil(t, errs) // Simulate invalid control character in URL while validating and setting up databases; @@ -732,7 +734,7 @@ func TestSetupAndValidateDatabases(t *testing.T) { databases["db1"] = &DbConfig{Name: "db1", BucketConfig: *bc} sc = &LegacyServerConfig{Databases: databases} - validationError := sc.setupAndValidateDatabases() + validationError := sc.setupAndValidateDatabases(ctx) require.NotNil(t, validationError) assert.Contains(t, validationError.Error(), "invalid control character in URL") } @@ -767,7 +769,7 @@ func TestParseCommandLine(t *testing.T) { "--logFilePath", logFilePath, "--pretty"} - config, err := ParseCommandLine(args, flag.ContinueOnError) + config, err := ParseCommandLine(base.TestCtx(t), args, flag.ContinueOnError) require.NoError(t, err, "Parsing commandline arguments without any config file") assert.Equal(t, interfaceAddress, *config.Interface) assert.Equal(t, adminInterface, *config.AdminInterface) @@ -812,19 +814,19 @@ func TestGetCredentialsFromClusterConfig(t *testing.T) { func TestSetMaxFileDescriptors(t *testing.T) { var maxFDs *uint64 - err := SetMaxFileDescriptors(maxFDs) + err := SetMaxFileDescriptors(base.TestCtx(t), maxFDs) assert.NoError(t, err, "Sets file descriptor limit to default when requested soft limit is nil") // Set MaxFileDescriptors maxFDsHigher := DefaultMaxFileDescriptors + 1 - err = SetMaxFileDescriptors(&maxFDsHigher) + err = SetMaxFileDescriptors(base.TestCtx(t), &maxFDsHigher) assert.NoError(t, err, "Error setting MaxFileDescriptors") } func TestParseCommandLineWithMissingConfig(t *testing.T) { // Parse command line options with unknown sync gateway configuration file args := []string{"sync_gateway", "missing-sync-gateway.conf"} - config, err := ParseCommandLine(args, flag.ContinueOnError) + config, err := ParseCommandLine(base.TestCtx(t), args, flag.ContinueOnError) require.Error(t, err, "Trying to read configuration file which doesn't exist") assert.Nil(t, config) } @@ -847,7 +849,7 @@ func TestParseCommandLineWithBadConfigContent(t *testing.T) { }() args := []string{"sync_gateway", configFile.Name()} - config, err := ParseCommandLine(args, flag.ContinueOnError) + config, err := ParseCommandLine(base.TestCtx(t), args, flag.ContinueOnError) assert.Error(t, err, "Parsing configuration file with an unknown field") assert.Nil(t, config) } @@ -902,7 +904,7 @@ func TestParseCommandLineWithConfigContent(t *testing.T) { "--profileInterface", profileInterface, configFile.Name()} - config, err := ParseCommandLine(args, flag.ContinueOnError) + config, err := ParseCommandLine(base.TestCtx(t), args, flag.ContinueOnError) require.NoError(t, err, "while parsing commandline options") assert.Equal(t, interfaceAddress, *config.Interface) assert.Equal(t, adminInterface, *config.AdminInterface) @@ -990,10 +992,10 @@ func TestValidateServerContextSharedBuckets(t *testing.T) { NumIndexReplicas: base.UintPtr(0), }, } + ctx := base.TestCtx(t) - require.Nil(t, setupAndValidateDatabases(databases), "Unexpected error while validating databases") + require.Nil(t, setupAndValidateDatabases(ctx, databases), "Unexpected error while validating databases") - ctx := base.TestCtx(t) sc := NewServerContext(ctx, config, false) defer sc.Close(ctx) for _, dbConfig := range databases { @@ -1001,7 +1003,7 @@ func TestValidateServerContextSharedBuckets(t *testing.T) { require.NoError(t, err, "Couldn't add database from config") } - sharedBucketErrors := sharedBucketDatabaseCheck(sc) + sharedBucketErrors := sharedBucketDatabaseCheck(ctx, sc) require.NotNil(t, sharedBucketErrors) multiError, ok := sharedBucketErrors.(*base.MultiError) require.NotNil(t, ok) @@ -1017,7 +1019,7 @@ func TestParseCommandLineWithIllegalOptionBucket(t *testing.T) { "sync_gateway", "--bucket", "sync_gateway", // Bucket option has been removed } - config, err := ParseCommandLine(args, flag.ContinueOnError) + config, err := ParseCommandLine(base.TestCtx(t), args, flag.ContinueOnError) assert.Error(t, err, "Parsing commandline arguments without any config file") assert.Empty(t, config, "Couldn't parse commandline arguments") } @@ -1077,7 +1079,7 @@ func TestEnvDefaultExpansion(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - actualValue, err := envDefaultExpansion(test.envKey, func(s string) string { return test.envValue }) + actualValue, err := envDefaultExpansion(base.TestCtx(t), test.envKey, func(s string) string { return test.envValue }) require.Equal(t, err, test.expectedError) assert.Equal(t, test.expectedValue, actualValue) }) @@ -1281,7 +1283,7 @@ func TestExpandEnv(t *testing.T) { require.Equal(t, v, value, "Unexpected value set for environment variable %q", k) } // Check environment variable substitutions. - actualConfig, err := expandEnv(test.inputConfig) + actualConfig, err := expandEnv(base.TestCtx(t), test.inputConfig) if test.expectedError != nil { errs, ok := err.(*base.MultiError) require.True(t, ok) @@ -1414,7 +1416,7 @@ func TestConfigGroupIDValidation(t *testing.T) { UseTLSServer: base.BoolPtr(base.ServerIsTLS(base.UnitTestUrl())), }, } - err := sc.Validate(isEnterpriseEdition) + err := sc.Validate(base.TestCtx(t), isEnterpriseEdition) if test.expectedError != "" { require.Error(t, err) assert.Contains(t, err.Error(), test.expectedError) @@ -1465,7 +1467,7 @@ func TestClientTLSMissing(t *testing.T) { if test.tlsCert { config.API.HTTPS.TLSCertPath = "test.cert" } - err := config.Validate(base.IsEnterpriseEdition()) + err := config.Validate(base.TestCtx(t), base.IsEnterpriseEdition()) if test.expectError { require.Error(t, err) assert.Contains(t, err.Error(), errorTLSOneMissing) @@ -1656,7 +1658,7 @@ func TestLoadJavaScript(t *testing.T) { teardownFn() } }() - js, err := loadJavaScript(inputJavaScriptOrPath, test.insecureSkipVerify) + js, err := loadJavaScript(base.TestCtx(t), inputJavaScriptOrPath, test.insecureSkipVerify) if test.errExpected != nil { requireErrorWithX509UnknownAuthority(t, err, test.errExpected) } else { @@ -1716,7 +1718,7 @@ func TestSetupDbConfigCredentials(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.dbConfig.setup(test.dbConfig.Name, test.bootstrapConfig, test.credentialsConfig, nil, false) + err := test.dbConfig.setup(base.TestCtx(t), test.dbConfig.Name, test.bootstrapConfig, test.credentialsConfig, nil, false) require.NoError(t, err) if test.expectX509 { assert.Equal(t, "", test.dbConfig.Username) @@ -1816,7 +1818,7 @@ func TestSetupDbConfigWithSyncFunction(t *testing.T) { RemoteConfigTlsSkipVerify: true, } } - err := dbConfig.setup(dbConfig.Name, BootstrapConfig{}, nil, nil, false) + err := dbConfig.setup(base.TestCtx(t), dbConfig.Name, BootstrapConfig{}, nil, nil, false) if test.errExpected != nil { requireErrorWithX509UnknownAuthority(t, err, test.errExpected) } else { @@ -1910,7 +1912,7 @@ func TestSetupDbConfigWithImportFilterFunction(t *testing.T) { RemoteConfigTlsSkipVerify: true, } } - err := dbConfig.setup(dbConfig.Name, BootstrapConfig{}, nil, nil, false) + err := dbConfig.setup(base.TestCtx(t), dbConfig.Name, BootstrapConfig{}, nil, nil, false) if test.errExpected != nil { requireErrorWithX509UnknownAuthority(t, err, test.errExpected) } else { @@ -2016,7 +2018,7 @@ func TestSetupDbConfigWithConflictResolutionFunction(t *testing.T) { RemoteConfigTlsSkipVerify: true, } } - err := dbConfig.setup(dbConfig.Name, BootstrapConfig{}, nil, nil, false) + err := dbConfig.setup(base.TestCtx(t), dbConfig.Name, BootstrapConfig{}, nil, nil, false) if test.errExpected != nil { requireErrorWithX509UnknownAuthority(t, err, test.errExpected) } else { @@ -2341,7 +2343,7 @@ func TestStartupConfigBcryptCostValidation(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { sc := StartupConfig{Auth: AuthConfig{BcryptCost: test.cost}} - err := sc.Validate(base.IsEnterpriseEdition()) + err := sc.Validate(base.TestCtx(t), base.IsEnterpriseEdition()) if test.expectError { require.Error(t, err) assert.Contains(t, err.Error(), errContains) @@ -2498,7 +2500,7 @@ func TestBucketCredentialsValidation(t *testing.T) { } for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - err := test.startupConfig.Validate(base.IsEnterpriseEdition()) + err := test.startupConfig.Validate(base.TestCtx(t), base.IsEnterpriseEdition()) if test.expectedError != nil { assert.Error(t, err) assert.Contains(t, err.Error(), *test.expectedError) @@ -2588,7 +2590,7 @@ func TestCollectionsValidation(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { validateOIDCConfig := false - err := test.dbConfig.validate(context.TODO(), validateOIDCConfig) + err := test.dbConfig.validate(base.TestCtx(t), validateOIDCConfig) if test.expectedError != nil { require.Error(t, err) require.Contains(t, err.Error(), *test.expectedError) diff --git a/rest/database_init_manager.go b/rest/database_init_manager.go index e5fbb9de1d..32f2d8226b 100644 --- a/rest/database_init_manager.go +++ b/rest/database_init_manager.go @@ -74,7 +74,7 @@ func (m *DatabaseInitManager) InitializeDatabase(ctx context.Context, startupCon base.InfofCtx(ctx, base.KeyAll, "Starting new async initialization for database %s ...", base.MD(dbConfig.Name)) - couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(startupConfig, base.PerUseClusterConnections) + couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(ctx, startupConfig, base.PerUseClusterConnections) if err != nil { return nil, err } diff --git a/rest/doc_api.go b/rest/doc_api.go index 129c0e6b95..4c278e8f0c 100644 --- a/rest/doc_api.go +++ b/rest/doc_api.go @@ -467,7 +467,7 @@ func (h *handler) handlePutDoc() error { h.setEtag(newRev) } else { // Replicator-style PUT with new_edits=false: - revisions := db.ParseRevisions(body) + revisions := db.ParseRevisions(h.ctx(), body) if revisions == nil { return base.HTTPErrorf(http.StatusBadRequest, "Bad _revisions") } @@ -515,7 +515,7 @@ func (h *handler) handlePutDocReplicator2(docid string, roundTrip bool) (err err parentRev = ifMatch } - generation, _ := db.ParseRevID(parentRev) + generation, _ := db.ParseRevID(h.ctx(), parentRev) generation++ deleted, _ := h.getOptBoolQuery("deleted", false) @@ -530,7 +530,7 @@ func (h *handler) handlePutDocReplicator2(docid string, roundTrip bool) (err err // Handle and pull out expiry if bytes.Contains(bodyBytes, []byte(db.BodyExpiry)) { - body := newDoc.Body() + body := newDoc.Body(h.ctx()) expiry, err := body.ExtractExpiry() if err != nil { return base.HTTPErrorf(http.StatusBadRequest, "Invalid expiry: %v", err) @@ -541,7 +541,7 @@ func (h *handler) handlePutDocReplicator2(docid string, roundTrip bool) (err err // Pull out attachments if bytes.Contains(bodyBytes, []byte(db.BodyAttachments)) { - body := newDoc.Body() + body := newDoc.Body(h.ctx()) newDoc.DocAttachments = db.GetBodyAttachments(body) delete(body, db.BodyAttachments) diff --git a/rest/functionsapitest/main_test.go b/rest/functionsapitest/main_test.go index 5ad62e91dd..fde7dfca9d 100644 --- a/rest/functionsapitest/main_test.go +++ b/rest/functionsapitest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package functionsapitest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 18192} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/handler.go b/rest/handler.go index 4b0cae46ad..b626ddd538 100644 --- a/rest/handler.go +++ b/rest/handler.go @@ -108,7 +108,7 @@ type handler struct { rqCtx context.Context } -type authScopeFunc func(bodyJSON []byte) (string, error) +type authScopeFunc func(ctx context.Context, bodyJSON []byte) (string, error) type handlerPrivs int @@ -373,7 +373,7 @@ func (h *handler) validateAndWriteHeaders(method handlerMethod, accessPermission base.InfofCtx(h.ctx(), base.KeyHTTP, "Error trying to get db %s: %v", base.MD(keyspaceDb), err) return err } - bucketName, _ = h.server.bucketNameFromDbName(keyspaceDb) + bucketName, _ = h.server.bucketNameFromDbName(h.ctx(), keyspaceDb) } } else { return err @@ -489,7 +489,7 @@ func (h *handler) validateAndWriteHeaders(method handlerMethod, accessPermission // The above readBody() will end up clearing the body which the later handler will require. Re-populate this // for the later handler. h.requestBody.reader = io.NopCloser(bytes.NewReader(body)) - authScope, err = h.authScopeFunc(body) + authScope, err = h.authScopeFunc(h.ctx(), body) if err != nil { return base.HTTPErrorf(http.StatusInternalServerError, "Unable to read body: %v", err) } @@ -498,7 +498,7 @@ func (h *handler) validateAndWriteHeaders(method handlerMethod, accessPermission } } - permissions, statusCode, err := checkAdminAuth(authScope, username, password, h.rq.Method, httpClient, + permissions, statusCode, err := checkAdminAuth(h.ctx(), authScope, username, password, h.rq.Method, httpClient, managementEndpoints, *h.server.Config.API.EnableAdminAuthenticationPermissionsCheck, accessPermissions, responsePermissions) if err != nil { @@ -878,7 +878,7 @@ func (h *handler) checkAdminAuthenticationOnly() (bool, error) { return false, ErrLoginRequired } - statusCode, _, err := doHTTPAuthRequest(httpClient, username, password, "POST", "/pools/default/checkPermissions", managementEndpoints, nil) + statusCode, _, err := doHTTPAuthRequest(h.ctx(), httpClient, username, password, "POST", "/pools/default/checkPermissions", managementEndpoints, nil) if err != nil { return false, base.HTTPErrorf(http.StatusInternalServerError, "Error performing HTTP auth request: %v", err) } @@ -890,9 +890,9 @@ func (h *handler) checkAdminAuthenticationOnly() (bool, error) { return true, nil } -func checkAdminAuth(bucketName, basicAuthUsername, basicAuthPassword string, attemptedHTTPOperation string, httpClient *http.Client, managementEndpoints []string, shouldCheckPermissions bool, accessPermissions []Permission, responsePermissions []Permission) (responsePermissionResults map[string]bool, statusCode int, err error) { +func checkAdminAuth(ctx context.Context, bucketName, basicAuthUsername, basicAuthPassword string, attemptedHTTPOperation string, httpClient *http.Client, managementEndpoints []string, shouldCheckPermissions bool, accessPermissions []Permission, responsePermissions []Permission) (responsePermissionResults map[string]bool, statusCode int, err error) { anyResponsePermFailed := false - permissionStatusCode, permResults, err := CheckPermissions(httpClient, managementEndpoints, bucketName, basicAuthUsername, basicAuthPassword, accessPermissions, responsePermissions) + permissionStatusCode, permResults, err := CheckPermissions(ctx, httpClient, managementEndpoints, bucketName, basicAuthUsername, basicAuthPassword, accessPermissions, responsePermissions) if err != nil { return nil, http.StatusInternalServerError, err } @@ -937,7 +937,7 @@ func checkAdminAuth(bucketName, basicAuthUsername, basicAuthPassword string, att } } - rolesStatusCode, err := CheckRoles(httpClient, managementEndpoints, basicAuthUsername, basicAuthPassword, requestRoles, bucketName) + rolesStatusCode, err := CheckRoles(ctx, httpClient, managementEndpoints, basicAuthUsername, basicAuthPassword, requestRoles, bucketName) if err != nil { return nil, http.StatusInternalServerError, err } @@ -1118,7 +1118,7 @@ func (h *handler) readSanitizeJSON(val interface{}) error { } // Expand environment variables. - content, err = expandEnv(content) + content, err = expandEnv(h.ctx(), content) if err != nil { return err } diff --git a/rest/handler_config_database.go b/rest/handler_config_database.go index 9d88d92edf..9f46691808 100644 --- a/rest/handler_config_database.go +++ b/rest/handler_config_database.go @@ -62,7 +62,7 @@ func (h *handler) mutateDbConfig(mutator func(*DbConfig) error) error { return nil, base.HTTPErrorf(http.StatusBadRequest, err.Error()) } - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(h.ctx(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (h *handler) mutateDbConfig(mutator func(*DbConfig) error) error { dbCreds := h.server.Config.DatabaseCredentials[dbName] bucketCreds := h.server.Config.BucketCredentials[bucket] - if err := updatedDbConfig.setup(dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { + if err := updatedDbConfig.setup(h.ctx(), dbName, h.server.Config.Bootstrap, dbCreds, bucketCreds, h.server.Config.IsServerless()); err != nil { return err } diff --git a/rest/importtest/import_test.go b/rest/importtest/import_test.go index 0c731d48e6..5a1a8a2fbc 100644 --- a/rest/importtest/import_test.go +++ b/rest/importtest/import_test.go @@ -1266,7 +1266,7 @@ func TestCheckForUpgradeFeed(t *testing.T) { assert.NoError(t, err, "Error writing SDK doc") // We don't have a way to wait for a upgrade that doesn't happen, but we can look for the warning that happens. - base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.Cache().NonMobileIgnoredCount.Value() }, 1) } @@ -2163,7 +2163,7 @@ func assertXattrSyncMetaRevGeneration(t *testing.T, dataStore base.DataStore, ke assert.NoError(t, err, "Error Getting Xattr") revision, ok := xattr["rev"] assert.True(t, ok) - generation, _ := db.ParseRevID(revision.(string)) + generation, _ := db.ParseRevID(base.TestCtx(t), revision.(string)) log.Printf("assertXattrSyncMetaRevGeneration generation: %d rev: %s", generation, revision) assert.True(t, generation == expectedRevGeneration) } diff --git a/rest/importtest/main_test.go b/rest/importtest/main_test.go index 62bcfebdc8..4a6d376599 100644 --- a/rest/importtest/main_test.go +++ b/rest/importtest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package importtest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/indextest/index_test.go b/rest/indextest/index_test.go index ffd2f47e2d..4194a4bb69 100644 --- a/rest/indextest/index_test.go +++ b/rest/indextest/index_test.go @@ -138,7 +138,7 @@ func TestAsyncInitializeIndexes(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Set testing callbacks for async initialization collectionCount := int64(0) @@ -253,7 +253,7 @@ func TestAsyncInitWithResync(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Seed the bucket with some documents tb := base.GetTestBucket(t) @@ -381,7 +381,7 @@ func TestAsyncOnlineOffline(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Set testing callbacks for async initialization collectionCount := int64(0) @@ -504,7 +504,7 @@ func TestAsyncCreateThenDelete(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Set testing callbacks for async initialization collectionCount := int64(0) @@ -624,7 +624,7 @@ func TestSyncOnline(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Set testing callbacks for async initialization collectionCount := int64(0) diff --git a/rest/indextest/main_test.go b/rest/indextest/main_test.go index f9f0d19162..9ae8fba408 100644 --- a/rest/indextest/main_test.go +++ b/rest/indextest/main_test.go @@ -11,13 +11,15 @@ licenses/APL2.txt. package indextest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 2048} // Do not create indexes for this test, so they are built by server_context.go - base.TestBucketPoolNoIndexes(m, tbpOptions) + base.TestBucketPoolNoIndexes(ctx, m, tbpOptions) } diff --git a/rest/main.go b/rest/main.go index 184364b6c3..764a56d77e 100644 --- a/rest/main.go +++ b/rest/main.go @@ -29,8 +29,9 @@ import ( // function directly calls this. It registers both signal and fatal panic handlers, // does the initial setup and finally starts the server. func ServerMain() { - if err := serverMain(context.Background(), os.Args); err != nil { - base.FatalfCtx(context.TODO(), "Couldn't start Sync Gateway: %v", err) + ctx := context.Background() // main context + if err := serverMain(ctx, os.Args); err != nil { + base.FatalfCtx(ctx, "Couldn't start Sync Gateway: %v", err) } } @@ -42,7 +43,7 @@ func serverMain(ctx context.Context, osArgs []string) error { base.InitializeMemoryLoggers() base.LogSyncGatewayVersion(ctx) - flagStartupConfig, fs, disablePersistentConfig, err := parseFlags(osArgs) + flagStartupConfig, fs, disablePersistentConfig, err := parseFlags(ctx, osArgs) if err != nil { // Return nil for ErrHelp so the shell exit code is 0 if err == flag.ErrHelp { @@ -78,7 +79,7 @@ func serverMainPersistentConfig(ctx context.Context, fs *flag.FlagSet, flagStart var legacyDbUsers map[string]map[string]*auth.PrincipalConfig // [db][user]PrincipleConfig var legacyDbRoles map[string]map[string]*auth.PrincipalConfig // [db][roles]PrincipleConfig if len(configPath) == 1 { - fileStartupConfig, err = LoadStartupConfigFromPath(configPath[0]) + fileStartupConfig, err = LoadStartupConfigFromPath(ctx, configPath[0]) if pkgerrors.Cause(err) == base.ErrUnknownField { // If we have an unknown field error processing config its possible that the config is a 2.x config // requiring automatic upgrade. We should attempt to perform this upgrade @@ -86,7 +87,7 @@ func serverMainPersistentConfig(ctx context.Context, fs *flag.FlagSet, flagStart base.InfofCtx(ctx, base.KeyAll, "Found unknown fields in startup config. Attempting to read as legacy config.") var upgradeError error - fileStartupConfig, disablePersistentConfigFallback, legacyDbUsers, legacyDbRoles, upgradeError = automaticConfigUpgrade(configPath[0]) + fileStartupConfig, disablePersistentConfigFallback, legacyDbUsers, legacyDbRoles, upgradeError = automaticConfigUpgrade(ctx, configPath[0]) if upgradeError != nil { // We need to validate if the error was again, an unknown field error. If this is the case its possible @@ -182,8 +183,8 @@ func getInitialStartupConfig(fileStartupConfig *StartupConfig, flagStartupConfig // automaticConfigUpgrade takes the config path of the current 2.x config and attempts to perform the update steps to // update it to a 3.x config // Returns the new startup config, a bool of whether to fallback to legacy config, map of users per database, map of roles per database, and an error -func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersistentConfig bool, users map[string]map[string]*auth.PrincipalConfig, roles map[string]map[string]*auth.PrincipalConfig, err error) { - legacyServerConfig, err := LoadLegacyServerConfig(configPath) +func automaticConfigUpgrade(ctx context.Context, configPath string) (sc *StartupConfig, disablePersistentConfig bool, users map[string]map[string]*auth.PrincipalConfig, roles map[string]map[string]*auth.PrincipalConfig, err error) { + legacyServerConfig, err := LoadLegacyServerConfig(ctx, configPath) if err != nil { return nil, false, nil, nil, err } @@ -192,9 +193,9 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis return nil, true, users, roles, nil } - base.InfofCtx(context.Background(), base.KeyAll, "Config is a legacy config, and disable_persistent_config was not requested. Attempting automatic config upgrade.") + base.InfofCtx(ctx, base.KeyAll, "Config is a legacy config, and disable_persistent_config was not requested. Attempting automatic config upgrade.") - startupConfig, dbConfigs, err := legacyServerConfig.ToStartupConfig() + startupConfig, dbConfigs, err := legacyServerConfig.ToStartupConfig(ctx) if err != nil { return nil, false, nil, nil, err } @@ -205,7 +206,7 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis } // Attempt to establish connection to server - cluster, err := CreateCouchbaseClusterFromStartupConfig(startupConfig, base.PerUseClusterConnections) + cluster, err := CreateCouchbaseClusterFromStartupConfig(ctx, startupConfig, base.PerUseClusterConnections) if err != nil { return nil, false, nil, nil, err } @@ -217,7 +218,7 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis for _, dbConfig := range dbConfigs { dbc := dbConfig.ToDatabaseConfig() - dbc.Version, err = GenerateDatabaseConfigVersionID("", &dbc.DbConfig) + dbc.Version, err = GenerateDatabaseConfigVersionID(ctx, "", &dbc.DbConfig) if err != nil { return nil, false, nil, nil, err } @@ -237,16 +238,16 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis configGroupID = startupConfig.Bootstrap.ConfigGroupID } - _, err = bootstrap.InsertConfig(context.Background(), *dbc.Bucket, configGroupID, dbc) + _, err = bootstrap.InsertConfig(ctx, *dbc.Bucket, configGroupID, dbc) if err != nil { // If key already exists just continue if errors.Is(err, base.ErrAlreadyExists) { - base.InfofCtx(context.Background(), base.KeyAll, "Skipping Couchbase Server persistence for config group %q in %s. Already exists.", configGroupID, base.UD(dbc.Name)) + base.InfofCtx(ctx, base.KeyAll, "Skipping Couchbase Server persistence for config group %q in %s. Already exists.", configGroupID, base.UD(dbc.Name)) continue } return nil, false, nil, nil, err } - base.InfofCtx(context.Background(), base.KeyAll, "Persisted database %s config for group %q to Couchbase Server bucket: %s", base.UD(dbc.Name), configGroupID, base.MD(*dbc.Bucket)) + base.InfofCtx(ctx, base.KeyAll, "Persisted database %s config for group %q to Couchbase Server bucket: %s", base.UD(dbc.Name), configGroupID, base.MD(*dbc.Bucket)) } // Attempt to backup current config @@ -254,11 +255,11 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis // Otherwise continue with startup but don't attempt to write migrated config and log warning backupLocation, err := backupCurrentConfigFile(configPath) if err != nil { - base.WarnfCtx(context.Background(), "Unable to write config file backup: %v. Won't write backup or updated config but will continue with startup.", err) + base.WarnfCtx(ctx, "Unable to write config file backup: %v. Won't write backup or updated config but will continue with startup.", err) return startupConfig, false, users, roles, nil } - base.InfofCtx(context.Background(), base.KeyAll, "Current config backed up to %s", base.MD(backupLocation)) + base.InfofCtx(ctx, base.KeyAll, "Current config backed up to %s", base.MD(backupLocation)) // Overwrite old config with new migrated startup config jsonStartupConfig, err := json.MarshalIndent(startupConfig, "", " ") @@ -271,11 +272,11 @@ func automaticConfigUpgrade(configPath string) (sc *StartupConfig, disablePersis // Otherwise continue with startup but log warning err = os.WriteFile(configPath, jsonStartupConfig, 0644) if err != nil { - base.WarnfCtx(context.Background(), "Unable to write updated config file: %v - but will continue with startup.", err) + base.WarnfCtx(ctx, "Unable to write updated config file: %v - but will continue with startup.", err) return startupConfig, false, users, roles, nil } - base.InfofCtx(context.Background(), base.KeyAll, "Current config file overwritten by upgraded config at %s", base.MD(configPath)) + base.InfofCtx(ctx, base.KeyAll, "Current config file overwritten by upgraded config at %s", base.MD(configPath)) return startupConfig, false, users, roles, nil } @@ -367,12 +368,12 @@ func backupCurrentConfigFile(sourcePath string) (string, error) { return backupPath, nil } -func CreateCouchbaseClusterFromStartupConfig(config *StartupConfig, bucketConnectionMode base.BucketConnectionMode) (*base.CouchbaseCluster, error) { - cluster, err := base.NewCouchbaseCluster(config.Bootstrap.Server, config.Bootstrap.Username, config.Bootstrap.Password, +func CreateCouchbaseClusterFromStartupConfig(ctx context.Context, config *StartupConfig, bucketConnectionMode base.BucketConnectionMode) (*base.CouchbaseCluster, error) { + cluster, err := base.NewCouchbaseCluster(ctx, config.Bootstrap.Server, config.Bootstrap.Username, config.Bootstrap.Password, config.Bootstrap.X509CertPath, config.Bootstrap.X509KeyPath, config.Bootstrap.CACertPath, config.IsServerless(), config.BucketCredentials, config.Bootstrap.ServerTLSSkipVerify, config.Unsupported.UseXattrConfig, bucketConnectionMode) if err != nil { - base.InfofCtx(context.Background(), base.KeyConfig, "Couldn't create couchbase cluster instance: %v", err) + base.InfofCtx(ctx, base.KeyConfig, "Couldn't create couchbase cluster instance: %v", err) return nil, err } @@ -380,7 +381,7 @@ func CreateCouchbaseClusterFromStartupConfig(config *StartupConfig, bucketConnec } // parseFlags handles the parsing of legacy and persistent config flags. -func parseFlags(args []string) (flagStartupConfig *StartupConfig, fs *flag.FlagSet, disablePersistentConfig *bool, err error) { +func parseFlags(ctx context.Context, args []string) (flagStartupConfig *StartupConfig, fs *flag.FlagSet, disablePersistentConfig *bool, err error) { fs = flag.NewFlagSet(args[0], flag.ContinueOnError) // used by service scripts as a way to specify a per-distro defaultLogFilePath @@ -404,7 +405,7 @@ func parseFlags(args []string) (flagStartupConfig *StartupConfig, fs *flag.FlagS return nil, nil, nil, fmt.Errorf("error merging flags on to config: %w", err) } - err = fillConfigWithLegacyFlags(legacyConfigFlags, fs, startupConfig.Logging.Console.LogLevel != nil) + err = fillConfigWithLegacyFlags(ctx, legacyConfigFlags, fs, startupConfig.Logging.Console.LogLevel != nil) if err != nil { return nil, nil, nil, fmt.Errorf("error merging legacy flags on to config: %w", err) } diff --git a/rest/main_legacy.go b/rest/main_legacy.go index eee215f685..ca4e37ba9a 100644 --- a/rest/main_legacy.go +++ b/rest/main_legacy.go @@ -24,7 +24,7 @@ const flagDeprecated = `Flag "%s" is deprecated. Please use "%s" in future.` func legacyServerMain(ctx context.Context, osArgs []string, flagStartupConfig *StartupConfig) error { base.WarnfCtx(ctx, "Running in legacy config mode") - lc, err := setupServerConfig(osArgs) + lc, err := setupServerConfig(ctx, osArgs) if err != nil { return err } @@ -33,7 +33,7 @@ func legacyServerMain(ctx context.Context, osArgs []string, flagStartupConfig *S lc.DisablePersistentConfig = base.BoolPtr(true) - migratedStartupConfig, databases, err := lc.ToStartupConfig() + migratedStartupConfig, databases, err := lc.ToStartupConfig(ctx) if err != nil { return err } @@ -98,7 +98,7 @@ func registerLegacyFlags(config *StartupConfig, fs *flag.FlagSet) map[string]leg } } -func fillConfigWithLegacyFlags(flags map[string]legacyConfigFlag, fs *flag.FlagSet, consoleLogLevelSet bool) error { +func fillConfigWithLegacyFlags(ctx context.Context, flags map[string]legacyConfigFlag, fs *flag.FlagSet, consoleLogLevelSet bool) error { var errors *base.MultiError fs.Visit(func(f *flag.Flag) { cfgFlag, legacyFlag := flags[f.Name] @@ -108,30 +108,30 @@ func fillConfigWithLegacyFlags(flags map[string]legacyConfigFlag, fs *flag.FlagS switch f.Name { case "interface", "adminInterface", "profileInterface", "url", "certpath", "keypath", "cacertpath", "logFilePath": *cfgFlag.config.(*string) = *cfgFlag.flagValue.(*string) - base.WarnfCtx(context.Background(), flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) + base.WarnfCtx(ctx, flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) case "pretty": rCfg := reflect.ValueOf(cfgFlag.config).Elem() rFlag := reflect.ValueOf(cfgFlag.flagValue) rCfg.Set(rFlag) - base.WarnfCtx(context.Background(), flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) + base.WarnfCtx(ctx, flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) case "verbose": if *cfgFlag.flagValue.(*bool) { if consoleLogLevelSet { - base.WarnfCtx(context.Background(), `Cannot use deprecated flag "-verbose" with flag "-logging.console.log_level". To set Sync Gateway to be verbose, please use flag "-logging.console.log_level info". Ignoring flag...`) + base.WarnfCtx(ctx, `Cannot use deprecated flag "-verbose" with flag "-logging.console.log_level". To set Sync Gateway to be verbose, please use flag "-logging.console.log_level info". Ignoring flag...`) } else { *cfgFlag.config.(**base.LogLevel) = base.LogLevelPtr(base.LevelInfo) - base.WarnfCtx(context.Background(), flagDeprecated, "-"+f.Name, "-logging.console.log_level info") + base.WarnfCtx(ctx, flagDeprecated, "-"+f.Name, "-logging.console.log_level info") } } case "log": list := strings.Split(*cfgFlag.flagValue.(*string), ",") *cfgFlag.config.(*[]string) = list - base.WarnfCtx(context.Background(), flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) + base.WarnfCtx(ctx, flagDeprecated, "-"+f.Name, "-"+cfgFlag.supersededFlag) case "configServer": err := fmt.Errorf(`flag "-%s" is no longer supported and has been removed`, f.Name) errors = errors.Append(err) case "dbname", "deploymentID": - base.WarnfCtx(context.Background(), `Flag "-%s" is no longer supported and has been removed.`, f.Name) + base.WarnfCtx(ctx, `Flag "-%s" is no longer supported and has been removed.`, f.Name) } }) return errors.ErrorOrNil() diff --git a/rest/main_legacy_test.go b/rest/main_legacy_test.go index 1c5702b414..fa18f56387 100644 --- a/rest/main_legacy_test.go +++ b/rest/main_legacy_test.go @@ -42,7 +42,7 @@ func TestLegacyFlagsValid(t *testing.T) { }) require.NoError(t, err) - err = fillConfigWithLegacyFlags(flags, fs, false) + err = fillConfigWithLegacyFlags(base.TestCtx(t), flags, fs, false) assert.NoError(t, err) assert.Equal(t, "12.34.56.78", config.API.PublicInterface) @@ -70,7 +70,7 @@ func TestLegacyFlagsError(t *testing.T) { }) require.NoError(t, err) - err = fillConfigWithLegacyFlags(flags, fs, false) + err = fillConfigWithLegacyFlags(base.TestCtx(t), flags, fs, false) require.Error(t, err) assert.Contains(t, err.Error(), errorText) } diff --git a/rest/main_test.go b/rest/main_test.go index 428d3ad04e..7f4d016a82 100644 --- a/rest/main_test.go +++ b/rest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package rest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -20,8 +21,9 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 8192} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } func TestConfigOverwritesLegacyFlags(t *testing.T) { @@ -38,7 +40,7 @@ func TestConfigOverwritesLegacyFlags(t *testing.T) { "config.json", } - sc, _, _, err := parseFlags(osArgs) + sc, _, _, err := parseFlags(base.TestCtx(t), osArgs) assert.NoError(t, err) require.NotNil(t, sc) @@ -85,7 +87,7 @@ func TestParseFlags(t *testing.T) { } for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - _, _, disablePersistentConfig, err := parseFlags(append(osArgsPrefix, test.osArgs...)) + _, _, disablePersistentConfig, err := parseFlags(base.TestCtx(t), append(osArgsPrefix, test.osArgs...)) if test.expectedError != nil { require.Error(t, err) assert.Contains(t, err.Error(), *test.expectedError) diff --git a/rest/oidc_api_test.go b/rest/oidc_api_test.go index 86f7a28b9b..f2730f7f58 100644 --- a/rest/oidc_api_test.go +++ b/rest/oidc_api_test.go @@ -12,7 +12,6 @@ package rest import ( "bytes" - "context" "crypto/rand" "crypto/rsa" "encoding/json" @@ -340,7 +339,6 @@ func (s *mockAuthServer) makeToken(claimSet claimSet) (string, error) { builder := jwt.Signed(s.signer).Claims(primaryClaims).Claims(secondaryClaims) token, err := builder.CompactSerialize() if err != nil { - base.ErrorfCtx(context.TODO(), "Error serializing token: %s", err) return "", err } return token, nil @@ -1162,11 +1160,12 @@ func TestOpenIDConnectImplicitFlowReuseToken(t *testing.T) { require.NoError(t, restTester.WaitForPendingChanges()) - u, err := restTester.GetDatabase().Authenticator(base.TestCtx(t)).GetUser("foo_noah") + ctx := base.DatabaseLogCtx(base.TestCtx(t), restTester.GetDatabase().Name, nil) + u, err := restTester.GetDatabase().Authenticator(ctx).GetUser("foo_noah") require.NoError(t, err) firstJWTLastUpdated := u.JWTLastUpdated() - lastSeq, err := restTester.GetDatabase().LastSequence() + lastSeq, err := restTester.GetDatabase().LastSequence(ctx) assert.NoError(t, err) // Observing an updated user inside the changes request isn't deterministic, as it depends on the timing of the DCP feed for the principal update made during the changes request... @@ -1185,12 +1184,12 @@ func TestOpenIDConnectImplicitFlowReuseToken(t *testing.T) { assert.Equalf(t, int64(0), observedUserUpdateCount, "%d of %d changes observed user update (expected 0)", observedUserUpdateCount, numChanges) // since we made no changes to channels, we shouldn't expect the user to actually be updated with a new JWT timestamp. - u, err = restTester.GetDatabase().Authenticator(base.TestCtx(t)).GetUser("foo_noah") + u, err = restTester.GetDatabase().Authenticator(ctx).GetUser("foo_noah") require.NoError(t, err) finalJWTLastUpdated := u.JWTLastUpdated() assert.Equal(t, firstJWTLastUpdated, finalJWTLastUpdated) - finalLastSeq, err := restTester.GetDatabase().LastSequence() + finalLastSeq, err := restTester.GetDatabase().LastSequence(ctx) assert.NoError(t, err) assert.Equal(t, int64(lastSeq), int64(finalLastSeq)) @@ -1205,12 +1204,12 @@ func TestOpenIDConnectImplicitFlowReuseToken(t *testing.T) { require.NoError(t, json.Unmarshal(resp.BodyBytes(), &changesResp)) assert.Lenf(t, changesResp.Results, 1, "Expected user update on changes feed") - u, err = restTester.GetDatabase().Authenticator(base.TestCtx(t)).GetUser("foo_noah") + u, err = restTester.GetDatabase().Authenticator(ctx).GetUser("foo_noah") require.NoError(t, err) postUpdateJWTLastUpdated := u.JWTLastUpdated() base.AssertTimeGreaterThan(t, postUpdateJWTLastUpdated, finalJWTLastUpdated) - postUpdateLastSeq, err := restTester.GetDatabase().LastSequence() + postUpdateLastSeq, err := restTester.GetDatabase().LastSequence(ctx) assert.NoError(t, err) assert.Equal(t, int64(finalLastSeq+1), int64(postUpdateLastSeq)) } @@ -2481,7 +2480,7 @@ func TestOpenIDConnectProviderRemoval(t *testing.T) { go func() { serverErr <- StartServer(ctx, &startupConfig, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) defer func() { sc.Close(ctx) require.NoError(t, <-serverErr) diff --git a/rest/oidc_test_provider.go b/rest/oidc_test_provider.go index 1ec49feaed..8de22f20b6 100644 --- a/rest/oidc_test_provider.go +++ b/rest/oidc_test_provider.go @@ -496,7 +496,7 @@ func handleRefreshTokenRequest(h *handler) error { refreshToken := h.rq.FormValue("refresh_token") // extract the subject from the refresh token - subject, err := extractSubjectFromRefreshToken(refreshToken) + subject, err := extractSubjectFromRefreshToken(h.ctx(), refreshToken) // Check for subject in map of known authenticated users authState, ok := authCodeTokenMap[subject] @@ -542,16 +542,16 @@ func writeTokenResponse(h *handler, subject string, issuerUrl string, authState return nil } -func extractSubjectFromRefreshToken(refreshToken string) (string, error) { +func extractSubjectFromRefreshToken(ctx context.Context, refreshToken string) (string, error) { decodedToken, err := base64.StdEncoding.DecodeString(refreshToken) if err != nil { - base.DebugfCtx(context.Background(), base.KeyAuth, "invalid refresh token provided, error: %v", err) + base.DebugfCtx(ctx, base.KeyAuth, "invalid refresh token provided, error: %v", err) return "", base.HTTPErrorf(http.StatusBadRequest, "Invalid OIDC Refresh Token") } components := strings.Split(string(decodedToken), ":::") subject := components[0] - base.DebugfCtx(context.Background(), base.KeyAuth, "subject extracted from refresh token = %v", subject) + base.DebugfCtx(ctx, base.KeyAuth, "subject extracted from refresh token = %v", subject) if len(components) != 2 || subject == "" { return "", base.HTTPErrorf(http.StatusBadRequest, "OIDC Refresh Token does not contain subject") diff --git a/rest/oidc_test_provider_test.go b/rest/oidc_test_provider_test.go index 84c3395bc6..bf4734c7dc 100644 --- a/rest/oidc_test_provider_test.go +++ b/rest/oidc_test_provider_test.go @@ -107,8 +107,9 @@ func TestCreateJWTToken(t *testing.T) { func TestExtractSubjectFromRefreshToken(t *testing.T) { base.SetUpTestLogging(t, base.LevelDebug, base.KeyAuth) + ctx := base.TestCtx(t) // Extract subject from invalid refresh token - sub, err := extractSubjectFromRefreshToken("invalid_refresh_token") + sub, err := extractSubjectFromRefreshToken(ctx, "invalid_refresh_token") require.Error(t, err, "invalid refresh token error") assert.Contains(t, err.Error(), strconv.Itoa(http.StatusBadRequest)) assert.Empty(t, sub, "couldn't extract subject from refresh token") @@ -117,7 +118,7 @@ func TestExtractSubjectFromRefreshToken(t *testing.T) { subject := "subject" accessToken := base64.StdEncoding.EncodeToString([]byte(subject)) refreshToken := base64.StdEncoding.EncodeToString([]byte(subject + ":::" + accessToken)) - sub, err = extractSubjectFromRefreshToken(refreshToken) + sub, err = extractSubjectFromRefreshToken(ctx, refreshToken) require.NoError(t, err, "invalid refresh token error") assert.Equal(t, subject, sub) } diff --git a/rest/persistent_config_test.go b/rest/persistent_config_test.go index 663082616e..971727269b 100644 --- a/rest/persistent_config_test.go +++ b/rest/persistent_config_test.go @@ -58,7 +58,8 @@ func TestAutomaticConfigUpgrade(t *testing.T) { err := os.WriteFile(configPath, []byte(config), os.FileMode(0644)) require.NoError(t, err) - startupConfig, _, _, _, err := automaticConfigUpgrade(configPath) + ctx := base.TestCtx(t) + startupConfig, _, _, _, err := automaticConfigUpgrade(ctx, configPath) require.NoError(t, err) assert.Equal(t, "", startupConfig.Bootstrap.ConfigGroupID) @@ -97,7 +98,7 @@ func TestAutomaticConfigUpgrade(t *testing.T) { assert.Equal(t, config, string(writtenBackupFile)) - cbs, err := CreateCouchbaseClusterFromStartupConfig(startupConfig, base.PerUseClusterConnections) + cbs, err := CreateCouchbaseClusterFromStartupConfig(ctx, startupConfig, base.PerUseClusterConnections) require.NoError(t, err) bootstrapContext := &bootstrapContext{ @@ -105,7 +106,7 @@ func TestAutomaticConfigUpgrade(t *testing.T) { } var dbConfig DatabaseConfig - _, err = bootstrapContext.GetConfig(tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) + _, err = bootstrapContext.GetConfig(ctx, tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) require.NoError(t, err) assert.Equal(t, "db", dbConfig.Name) @@ -161,7 +162,7 @@ func TestAutomaticConfigUpgradeError(t *testing.T) { err := os.WriteFile(configPath, []byte(config), os.FileMode(0644)) require.NoError(t, err) - _, _, _, _, err = automaticConfigUpgrade(configPath) + _, _, _, _, err = automaticConfigUpgrade(base.TestCtx(t), configPath) assert.Error(t, err) }) } @@ -185,17 +186,18 @@ func TestUnmarshalBrokenConfig(t *testing.T) { ) RequireStatus(t, resp, http.StatusCreated) + ctx := base.TestCtx(t) // Use underlying connection to unmarshal to untyped config cnf := make(map[string]interface{}, 1) - key := PersistentConfigKey(rt.ServerContext().Config.Bootstrap.ConfigGroupID, "newdb") - cas, err := rt.ServerContext().BootstrapContext.Connection.GetMetadataDocument(tb.GetName(), key, &cnf) + key := PersistentConfigKey(ctx, rt.ServerContext().Config.Bootstrap.ConfigGroupID, "newdb") + cas, err := rt.ServerContext().BootstrapContext.Connection.GetMetadataDocument(ctx, tb.GetName(), key, &cnf) require.NoError(t, err) // Add invalid json fields to the config cnf["num_index_replicas"] = "0" // Both calls to UpdateMetadataDocument and fetchAndLoadConfigs needed to enter the broken state - _, err = rt.ServerContext().BootstrapContext.Connection.WriteMetadataDocument(tb.GetName(), key, cas, &cnf) + _, err = rt.ServerContext().BootstrapContext.Connection.WriteMetadataDocument(ctx, tb.GetName(), key, cas, &cnf) require.NoError(t, err) _, err = rt.ServerContext().fetchAndLoadConfigs(rt.Context(), false) assert.NoError(t, err) @@ -237,8 +239,9 @@ func TestAutomaticConfigUpgradeExistingConfigAndNewGroup(t *testing.T) { err := os.WriteFile(configPath, []byte(config), os.FileMode(0644)) require.NoError(t, err) + ctx := base.TestCtx(t) // Run migration once - _, _, _, _, err = automaticConfigUpgrade(configPath) + _, _, _, _, err = automaticConfigUpgrade(ctx, configPath) require.NoError(t, err) updatedConfig := fmt.Sprintf(`{ @@ -264,10 +267,10 @@ func TestAutomaticConfigUpgradeExistingConfigAndNewGroup(t *testing.T) { require.NoError(t, err) // Run migration again to ensure no error and validate it doesn't actually update db - startupConfig, _, _, _, err := automaticConfigUpgrade(updatedConfigPath) + startupConfig, _, _, _, err := automaticConfigUpgrade(ctx, updatedConfigPath) require.NoError(t, err) - cbs, err := CreateCouchbaseClusterFromStartupConfig(startupConfig, base.PerUseClusterConnections) + cbs, err := CreateCouchbaseClusterFromStartupConfig(ctx, startupConfig, base.PerUseClusterConnections) require.NoError(t, err) bootstrapContext := &bootstrapContext{ @@ -275,7 +278,7 @@ func TestAutomaticConfigUpgradeExistingConfigAndNewGroup(t *testing.T) { } var dbConfig DatabaseConfig - originalDefaultDbConfigCAS, err := bootstrapContext.GetConfig(tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) + originalDefaultDbConfigCAS, err := bootstrapContext.GetConfig(ctx, tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) assert.NoError(t, err) // Ensure that revs limit hasn't actually been set @@ -309,7 +312,7 @@ func TestAutomaticConfigUpgradeExistingConfigAndNewGroup(t *testing.T) { err = os.WriteFile(importConfigPath, []byte(importConfig), os.FileMode(0644)) require.NoError(t, err) - startupConfig, _, _, _, err = automaticConfigUpgrade(importConfigPath) + startupConfig, _, _, _, err = automaticConfigUpgrade(ctx, importConfigPath) // only supported in EE if base.IsEnterpriseEdition() { require.NoError(t, err) @@ -319,12 +322,12 @@ func TestAutomaticConfigUpgradeExistingConfigAndNewGroup(t *testing.T) { // Ensure dbConfig is saved as the specified config group ID var dbConfig DatabaseConfig - _, err = bootstrapContext.GetConfig(tb.GetName(), configUpgradeGroupID, "db", &dbConfig) + _, err = bootstrapContext.GetConfig(ctx, tb.GetName(), configUpgradeGroupID, "db", &dbConfig) assert.NoError(t, err) // Ensure default has not changed dbConfig = DatabaseConfig{} - defaultDbConfigCAS, err := bootstrapContext.GetConfig(tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) + defaultDbConfigCAS, err := bootstrapContext.GetConfig(ctx, tb.GetName(), PersistentConfigDefaultGroupID, "db", &dbConfig) assert.NoError(t, err) assert.Equal(t, originalDefaultDbConfigCAS, defaultDbConfigCAS) } else { @@ -360,7 +363,7 @@ func TestImportFilterEndpoint(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -435,7 +438,7 @@ func TestPersistentConfigWithCollectionConflicts(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -585,7 +588,7 @@ func TestPersistentConfigRegistryRollbackAfterCreateFailure(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -611,7 +614,7 @@ func TestPersistentConfigRegistryRollbackAfterCreateFailure(t *testing.T) { require.NoError(t, err) _, err = registry.upsertDatabaseConfig(ctx, groupID, config) require.NoError(t, err) - require.NoError(t, bc.setGatewayRegistry(bucketName, registry)) + require.NoError(t, bc.setGatewayRegistry(ctx, bucketName, registry)) } // set up ScopesConfigs used by tests @@ -732,7 +735,7 @@ func TestPersistentConfigRegistryRollbackAfterUpdateFailure(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -771,7 +774,7 @@ func TestPersistentConfigRegistryRollbackAfterUpdateFailure(t *testing.T) { require.NoError(t, err) _, err = registry.upsertDatabaseConfig(ctx, groupID, config) require.NoError(t, err) - require.NoError(t, bc.setGatewayRegistry(bucketName, registry)) + require.NoError(t, bc.setGatewayRegistry(ctx, bucketName, registry)) } // Case 1. GetDatabaseConfigs should roll back registry after update failure @@ -886,7 +889,7 @@ func TestPersistentConfigRegistryRollbackAfterDeleteFailure(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -924,7 +927,7 @@ func TestPersistentConfigRegistryRollbackAfterDeleteFailure(t *testing.T) { registry, err := bc.getGatewayRegistry(ctx, bucketName) require.NoError(t, err) require.NoError(t, registry.deleteDatabase(groupID, config.Name)) - require.NoError(t, bc.setGatewayRegistry(bucketName, registry)) + require.NoError(t, bc.setGatewayRegistry(ctx, bucketName, registry)) } // Case 1. Retrieval of database after delete failure should not find it (matching versions) @@ -1003,7 +1006,7 @@ func TestPersistentConfigSlowCreateFailure(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -1029,11 +1032,11 @@ func TestPersistentConfigSlowCreateFailure(t *testing.T) { require.NoError(t, err) _, err = registry.upsertDatabaseConfig(ctx, groupID, config) require.NoError(t, err) - require.NoError(t, bc.setGatewayRegistry(bucketName, registry)) + require.NoError(t, bc.setGatewayRegistry(ctx, bucketName, registry)) } completeSlowCreate := func(t *testing.T, config *DatabaseConfig) error { - _, insertError := bc.Connection.InsertMetadataDocument(bucketName, PersistentConfigKey(groupID, config.Name), config) + _, insertError := bc.Connection.InsertMetadataDocument(ctx, bucketName, PersistentConfigKey(ctx, groupID, config.Name), config) return insertError } @@ -1081,7 +1084,7 @@ func TestMigratev30PersistentConfig(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -1100,7 +1103,7 @@ func TestMigratev30PersistentConfig(t *testing.T) { Version: defaultVersion, } - _, insertError := sc.BootstrapContext.Connection.InsertMetadataDocument(bucketName, PersistentConfigKey30(groupID), defaultDatabaseConfig) + _, insertError := sc.BootstrapContext.Connection.InsertMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), defaultDatabaseConfig) require.NoError(t, insertError) migrateErr := sc.migrateV30Configs(ctx) @@ -1114,11 +1117,11 @@ func TestMigratev30PersistentConfig(t *testing.T) { require.True(t, found) require.Equal(t, "1-abc", migratedDb.Version) // Verify legacy config has been removed - _, getError := sc.BootstrapContext.Connection.GetMetadataDocument(bucketName, PersistentConfigKey30(groupID), defaultDatabaseConfig) + _, getError := sc.BootstrapContext.Connection.GetMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), defaultDatabaseConfig) require.Equal(t, base.ErrNotFound, getError) // Update the db in the registry, and recreate legacy config. Verify migration doesn't overwrite - _, insertError = sc.BootstrapContext.Connection.InsertMetadataDocument(bucketName, PersistentConfigKey30(groupID), defaultDatabaseConfig) + _, insertError = sc.BootstrapContext.Connection.InsertMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), defaultDatabaseConfig) require.NoError(t, insertError) _, updateError := sc.BootstrapContext.UpdateConfig(ctx, bucketName, groupID, defaultDbName, func(bucketDbConfig *DatabaseConfig) (updatedConfig *DatabaseConfig, err error) { bucketDbConfig.Version = "2-abc" @@ -1163,7 +1166,7 @@ func TestMigratev30PersistentConfigCollision(t *testing.T) { go func() { serverErr <- StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) // Get a test bucket, and use it to create the database. tb := base.GetTestBucket(t) @@ -1189,7 +1192,7 @@ func TestMigratev30PersistentConfigCollision(t *testing.T) { DbConfig: defaultDbConfig, Version: defaultVersion, } - _, insertError := sc.BootstrapContext.Connection.InsertMetadataDocument(bucketName, PersistentConfigKey30(groupID), defaultDatabaseConfig) + _, insertError := sc.BootstrapContext.Connection.InsertMetadataDocument(ctx, bucketName, PersistentConfigKey30(ctx, groupID), defaultDatabaseConfig) require.NoError(t, insertError) migrateErr := sc.migrateV30Configs(ctx) diff --git a/rest/replicatortest/main_test.go b/rest/replicatortest/main_test.go index 5aedd19b4d..cc52b4030a 100644 --- a/rest/replicatortest/main_test.go +++ b/rest/replicatortest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package replicatortest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,6 +19,7 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{MemWatermarkThresholdMB: 8192} - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/replicatortest/replicator_collection_test.go b/rest/replicatortest/replicator_collection_test.go index 1d2baab429..2b0cc8f89b 100644 --- a/rest/replicatortest/replicator_collection_test.go +++ b/rest/replicatortest/replicator_collection_test.go @@ -152,7 +152,7 @@ func TestActiveReplicatorMultiCollection(t *testing.T) { }) require.NoError(t, err) - assert.Equal(t, "", ar.GetStatus().LastSeqPull) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPull) // Start the replicator (implicit connect) require.NoError(t, ar.Start(ctx1)) diff --git a/rest/replicatortest/replicator_test.go b/rest/replicatortest/replicator_test.go index 5e68559bc3..13b5c702bb 100644 --- a/rest/replicatortest/replicator_test.go +++ b/rest/replicatortest/replicator_test.go @@ -2056,22 +2056,22 @@ func TestActiveReplicatorHeartbeats(t *testing.T) { }) require.NoError(t, err) - pingCountStart := base.ExpvarVar2Int(expvar.Get("goblip").(*expvar.Map).Get("sender_ping_count")) - pingGoroutinesStart := base.ExpvarVar2Int(expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) + pingCountStart := base.ExpvarVar2Int(ctx, expvar.Get("goblip").(*expvar.Map).Get("sender_ping_count")) + pingGoroutinesStart := base.ExpvarVar2Int(ctx, expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) assert.NoError(t, ar.Start(ctx)) // let some pings happen time.Sleep(time.Millisecond * 500) - pingGoroutines := base.ExpvarVar2Int(expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) + pingGoroutines := base.ExpvarVar2Int(ctx, expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) assert.Equal(t, 1+pingGoroutinesStart, pingGoroutines, "Expected ping sender goroutine to be 1 more than start") - pingCount := base.ExpvarVar2Int(expvar.Get("goblip").(*expvar.Map).Get("sender_ping_count")) + pingCount := base.ExpvarVar2Int(ctx, expvar.Get("goblip").(*expvar.Map).Get("sender_ping_count")) assert.Greaterf(t, pingCount, pingCountStart, "Expected ping count to increase since start") assert.NoError(t, ar.Stop()) - pingGoroutines = base.ExpvarVar2Int(expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) + pingGoroutines = base.ExpvarVar2Int(ctx, expvar.Get("goblip").(*expvar.Map).Get("goroutines_sender_ping")) assert.Equal(t, pingGoroutinesStart, pingGoroutines, "Expected ping sender goroutine to return to start count after stop") } @@ -2147,7 +2147,7 @@ func TestActiveReplicatorPullBasic(t *testing.T) { require.NoError(t, err) defer func() { assert.NoError(t, ar.Stop()) }() - assert.Equal(t, "", ar.GetStatus().LastSeqPull) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPull) // Start the replicator (implicit connect) require.NoError(t, ar.Start(ctx1)) @@ -2167,7 +2167,7 @@ func TestActiveReplicatorPullBasic(t *testing.T) { require.NoError(t, err) assert.Equal(t, "rt2", body["source"]) - assert.Equal(t, strconv.FormatUint(remoteDoc.Sequence, 10), ar.GetStatus().LastSeqPull) + assert.Equal(t, strconv.FormatUint(remoteDoc.Sequence, 10), ar.GetStatus(ctx1).LastSeqPull) } // TestActiveReplicatorPullSkippedSequence ensures that ISGR and the checkpointer are able to handle the compound sequence format appropriately. @@ -2489,13 +2489,13 @@ func TestTotalSyncTimeStat(t *testing.T) { activeRT.WaitForReplicationStatus(repName, db.ReplicationStateRunning) // wait for active replication stat to pick up the replication connection - _, ok := base.WaitForStat(func() int64 { + _, ok := base.WaitForStat(passiveRT.TB, func() int64 { return passiveRT.GetDatabase().DbStats.DatabaseStats.NumReplicationsActive.Value() }, 1) require.True(t, ok) // wait some time to wait for the stat to increment - _, ok = base.WaitForStat(func() int64 { + _, ok = base.WaitForStat(passiveRT.TB, func() int64 { return passiveRT.GetDatabase().DbStats.DatabaseStats.TotalSyncTime.Value() }, 2) require.True(t, ok) @@ -2549,16 +2549,14 @@ func TestChangesEndpointTotalSyncTime(t *testing.T) { }() // wait for active replication stat for CBL to pick up the replication connection - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.CBLReplicationPullStats.NumPullReplActiveContinuous.Value() }, 1) - require.True(t, ok) // wait some time to wait for the stat to increment - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.TotalSyncTime.Value() }, 2) - require.True(t, ok) syncTimeStat := rt.GetDatabase().DbStats.DatabaseStats.TotalSyncTime.Value() // we can't be certain how long has passed since grabbing the stat so to avoid flake here just assert the stat has incremented @@ -2867,12 +2865,13 @@ func TestActiveReplicatorPullMergeConflictingAttachments(t *testing.T) { doc, err := rt1.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) require.NoError(t, err) - revGen, _ := db.ParseRevID(doc.SyncData.CurrentRev) + ctx := base.TestCtx(t) + revGen, _ := db.ParseRevID(ctx, doc.SyncData.CurrentRev) assert.Equal(t, 3, revGen) - assert.Equal(t, "merged", doc.Body()["source"].(string)) + assert.Equal(t, "merged", doc.Body(ctx)["source"].(string)) - assert.Nil(t, doc.Body()[db.BodyAttachments], "_attachments property should not be in resolved doc body") + assert.Nil(t, doc.Body(ctx)[db.BodyAttachments], "_attachments property should not be in resolved doc body") assert.Len(t, doc.SyncData.Attachments, test.expectedAttachments, "mismatch in expected number of attachments in sync data of resolved doc") for attName, att := range doc.SyncData.Attachments { @@ -3139,10 +3138,9 @@ func TestActiveReplicatorPullFromCheckpointIgnored(t *testing.T) { pullCheckpointer := ar.Pull.GetSingleCollection(t).Checkpointer - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return pullCheckpointer.Stats().AlreadyKnownSequenceCount }, numRT2DocsInitial) - assert.True(t, ok) // wait for all of the documents originally written to rt2 to arrive at rt1 changesResults, err := rt1.WaitForChanges(numRT2DocsInitial, "/{{.keyspace}}/_changes?since=0", "", true) @@ -3201,10 +3199,9 @@ func TestActiveReplicatorPullFromCheckpointIgnored(t *testing.T) { // new replicator - new checkpointer pullCheckpointer = ar.Pull.GetSingleCollection(t).Checkpointer - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return pullCheckpointer.Stats().AlreadyKnownSequenceCount }, numRT2DocsTotal-numRT2DocsInitial) - assert.True(t, ok) // Make sure we've not started any more since:0 replications on rt2 since the first one endNumChangesRequestedFromZeroTotal := rt2.GetDatabase().DbStats.CBLReplicationPull().NumPullReplSinceZero.Value() @@ -3285,7 +3282,7 @@ func TestActiveReplicatorPullOneshot(t *testing.T) { require.NoError(t, err) defer func() { assert.NoError(t, ar.Stop()) }() - assert.Equal(t, "", ar.GetStatus().LastSeqPull) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPull) // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) @@ -3293,7 +3290,7 @@ func TestActiveReplicatorPullOneshot(t *testing.T) { // wait for the replication to stop replicationStopped := false for i := 0; i < 100; i++ { - status := ar.GetStatus() + status := ar.GetStatus(ctx1) if status.Status == db.ReplicationStateStopped { replicationStopped = true break @@ -3372,7 +3369,7 @@ func TestActiveReplicatorPushBasic(t *testing.T) { require.NoError(t, err) defer func() { assert.NoError(t, ar.Stop()) }() - assert.Equal(t, "", ar.GetStatus().LastSeqPush) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPush) // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) @@ -3392,7 +3389,7 @@ func TestActiveReplicatorPushBasic(t *testing.T) { require.NoError(t, err) assert.Equal(t, "rt1", body["source"]) - assert.Equal(t, strconv.FormatUint(localDoc.Sequence, 10), ar.GetStatus().LastSeqPush) + assert.Equal(t, strconv.FormatUint(localDoc.Sequence, 10), ar.GetStatus(ctx1).LastSeqPush) } // TestActiveReplicatorPushAttachments: @@ -3925,7 +3922,7 @@ func TestActiveReplicatorPushOneshot(t *testing.T) { require.NoError(t, err) defer func() { assert.NoError(t, ar.Stop()) }() - assert.Equal(t, "", ar.GetStatus().LastSeqPush) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPush) // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) @@ -3933,7 +3930,7 @@ func TestActiveReplicatorPushOneshot(t *testing.T) { // wait for the replication to stop replicationStopped := false for i := 0; i < 100; i++ { - status := ar.GetStatus() + status := ar.GetStatus(ctx1) if status.Status == db.ReplicationStateStopped { replicationStopped = true break @@ -3951,7 +3948,7 @@ func TestActiveReplicatorPushOneshot(t *testing.T) { require.NoError(t, err) assert.Equal(t, "rt1", body["source"]) - assert.Equal(t, strconv.FormatUint(localDoc.Sequence, 10), ar.GetStatus().LastSeqPush) + assert.Equal(t, strconv.FormatUint(localDoc.Sequence, 10), ar.GetStatus(ctx1).LastSeqPush) } // TestActiveReplicatorPullTombstone: @@ -4143,8 +4140,8 @@ func TestActiveReplicatorPullPurgeOnRemoval(t *testing.T) { rest.RequireStatus(t, resp, http.StatusCreated) // wait for the channel removal written to rt2 to arrive at rt1 - we can't monitor _changes, because we've purged, not removed. But we can monitor the associated stat. - base.WaitForStat(func() int64 { - stats := ar.GetStatus() + base.WaitForStat(t, func() int64 { + stats := ar.GetStatus(ctx1) return stats.DocsPurged }, 1) @@ -4281,7 +4278,7 @@ func TestActiveReplicatorPullConflict(t *testing.T) { rt1revID := rest.RespRevID(t, resp) assert.Equal(t, test.localRevID, rt1revID) - customConflictResolver, err := db.NewCustomConflictResolver(test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) + customConflictResolver, err := db.NewCustomConflictResolver(ctx1, test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) require.NoError(t, err) stats, err := base.SyncGatewayStats.NewDBStats(t.Name(), false, false, false, nil, nil) require.NoError(t, err) @@ -4307,7 +4304,7 @@ func TestActiveReplicatorPullConflict(t *testing.T) { // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) - rest.WaitAndRequireCondition(t, func() bool { return ar.GetStatus().DocsRead == 1 }, "Expecting DocsRead == 1") + rest.WaitAndRequireCondition(t, func() bool { return ar.GetStatus(ctx1).DocsRead == 1 }, "Expecting DocsRead == 1") switch test.expectedResolutionType { case db.ConflictResolutionLocal: assert.Equal(t, 1, int(replicationStats.ConflictResolvedLocalCount.Value())) @@ -4341,8 +4338,9 @@ func TestActiveReplicatorPullConflict(t *testing.T) { // This is skipped for tombstone tests running with xattr as xattr tombstones don't have a body to assert // against + ctx := base.TestCtx(t) if !test.skipBodyAssertion { - assert.Equal(t, test.expectedLocalBody, doc.Body()) + assert.Equal(t, test.expectedLocalBody, doc.Body(ctx)) } log.Printf("Doc %s is %+v", docID, doc) @@ -4515,7 +4513,7 @@ func TestActiveReplicatorPushAndPullConflict(t *testing.T) { localDoc, err := rt1.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalSync) require.NoError(t, err) - customConflictResolver, err := db.NewCustomConflictResolver(test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) + customConflictResolver, err := db.NewCustomConflictResolver(ctx1, test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) require.NoError(t, err) stats, err := base.SyncGatewayStats.NewDBStats(t.Name(), false, false, false, nil, nil) @@ -4542,8 +4540,8 @@ func TestActiveReplicatorPushAndPullConflict(t *testing.T) { // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) // wait for the document originally written to rt2 to arrive at rt1. Should end up as winner under default conflict resolution - base.WaitForStat(func() int64 { - return ar.GetStatus().DocsWritten + base.WaitForStat(t, func() int64 { + return ar.GetStatus(ctx1).DocsWritten }, 1) log.Printf("========================Replication should be done, checking with changes") @@ -4564,7 +4562,8 @@ func TestActiveReplicatorPushAndPullConflict(t *testing.T) { doc, err := rt1.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) require.NoError(t, err) assert.Equal(t, test.expectedRevID, doc.SyncData.CurrentRev) - assert.Equal(t, expectedLocalBody, doc.Body()) + ctx := base.TestCtx(t) + assert.Equal(t, expectedLocalBody, doc.Body(ctx)) log.Printf("Doc %s is %+v", docID, doc) log.Printf("Doc %s attachments are %+v", docID, doc.Attachments) for revID, revInfo := range doc.SyncData.History { @@ -4604,7 +4603,7 @@ func TestActiveReplicatorPushAndPullConflict(t *testing.T) { doc, err = rt2.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) require.NoError(t, err) assert.Equal(t, test.expectedRevID, doc.SyncData.CurrentRev) - assert.Equal(t, expectedLocalBody, doc.Body()) + assert.Equal(t, expectedLocalBody, doc.Body(ctx)) log.Printf("Remote Doc %s is %+v", docID, doc) log.Printf("Remote Doc %s attachments are %+v", docID, doc.Attachments) for revID, revInfo := range doc.SyncData.History { @@ -5008,10 +5007,9 @@ func TestActiveReplicatorRecoverFromRemoteFlush(t *testing.T) { assert.Equal(t, startNumChangesRequestedFromZeroTotal+1, numChangesRequestedFromZeroTotal) // rev assertions - _, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return ar.Push.GetStats().SendRevCount.Value() }, startNumRevsSentTotal+1) - assert.True(t, ok) assert.Equal(t, int64(1), pushCheckpointer.Stats().ProcessedSequenceCount) assert.Equal(t, int64(1), pushCheckpointer.Stats().ExpectedSequenceCount) @@ -5075,10 +5073,9 @@ func TestActiveReplicatorRecoverFromRemoteFlush(t *testing.T) { assert.Equal(t, numChangesRequestedFromZeroTotal+1, endNumChangesRequestedFromZeroTotal) // make sure the replicator has resent the rev - _, ok = base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return ar.Push.GetStats().SendRevCount.Value() }, startNumRevsSentTotal+1) - assert.True(t, ok) assert.Equal(t, int64(1), pushCheckpointer.Stats().ProcessedSequenceCount) assert.Equal(t, int64(1), pushCheckpointer.Stats().ExpectedSequenceCount) @@ -5162,7 +5159,7 @@ func TestActiveReplicatorRecoverFromRemoteRollback(t *testing.T) { pushCheckpointer := ar.Push.GetSingleCollection(t).Checkpointer - base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return ar.Push.GetStats().SendRevCount.Value() }, 1) @@ -5198,7 +5195,7 @@ func TestActiveReplicatorRecoverFromRemoteRollback(t *testing.T) { assert.NoError(t, rt1.WaitForPendingChanges()) - base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return ar.Push.GetStats().SendRevCount.Value() }, 2) @@ -5430,7 +5427,7 @@ func TestActiveReplicatorIgnoreNoConflicts(t *testing.T) { require.NoError(t, err) defer func() { assert.NoError(t, ar.Stop()) }() - assert.Equal(t, "", ar.GetStatus().LastSeqPush) + assert.Equal(t, "", ar.GetStatus(ctx1).LastSeqPush) // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) @@ -5861,7 +5858,7 @@ func TestActiveReplicatorReconnectOnStartEventualSuccess(t *testing.T) { rest.RequireStatus(t, resp, http.StatusCreated) rest.WaitAndRequireCondition(t, func() bool { - state, errMsg := ar.State() + state, errMsg := ar.State(ctx1) if strings.TrimSpace(errMsg) != "" && !strings.Contains(errMsg, msg401) { log.Println("unexpected replicator error:", errMsg) } @@ -5939,7 +5936,7 @@ func TestActiveReplicatorReconnectSendActions(t *testing.T) { assert.NoError(t, ar.Stop()) err = rt1.WaitForCondition(func() bool { - return ar.GetStatus().Status == db.ReplicationStateStopped + return ar.GetStatus(ctx1).Status == db.ReplicationStateStopped }) require.NoError(t, err) @@ -6204,7 +6201,7 @@ func TestActiveReplicatorPullConflictReadWriteIntlProps(t *testing.T) { rt1revID := rest.RespRevID(t, resp) assert.Equal(t, test.localRevID, rt1revID) - customConflictResolver, err := db.NewCustomConflictResolver(test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) + customConflictResolver, err := db.NewCustomConflictResolver(ctx1, test.conflictResolver, rt1.GetDatabase().Options.JavascriptTimeout) require.NoError(t, err) dbstats, err := base.SyncGatewayStats.NewDBStats(t.Name(), false, false, false, nil, nil) require.NoError(t, err) @@ -6229,7 +6226,7 @@ func TestActiveReplicatorPullConflictReadWriteIntlProps(t *testing.T) { // Start the replicator (implicit connect) assert.NoError(t, ar.Start(ctx1)) - rest.WaitAndRequireCondition(t, func() bool { return ar.GetStatus().DocsRead == 1 }) + rest.WaitAndRequireCondition(t, func() bool { return ar.GetStatus(ctx1).DocsRead == 1 }) assert.Equal(t, 1, int(replicationStats.ConflictResolvedMergedCount.Value())) // Wait for the document originally written to rt2 to arrive at rt1. @@ -6244,8 +6241,9 @@ func TestActiveReplicatorPullConflictReadWriteIntlProps(t *testing.T) { doc, err := rt1.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) require.NoError(t, err) assert.Equal(t, test.expectedLocalRevID, doc.SyncData.CurrentRev) - log.Printf("doc.Body(): %v", doc.Body()) - assert.Equal(t, test.expectedLocalBody, doc.Body()) + ctx := base.TestCtx(t) + log.Printf("doc.Body(): %v", doc.Body(ctx)) + assert.Equal(t, test.expectedLocalBody, doc.Body(ctx)) log.Printf("Doc %s is %+v", docID, doc) for revID, revInfo := range doc.SyncData.History { log.Printf("doc revision [%s]: %+v", revID, revInfo) @@ -6351,8 +6349,8 @@ func TestSGR2TombstoneConflictHandling(t *testing.T) { } compareDocRev := func(docRev, cmpRev string) (shouldRetry bool, err error, value interface{}) { - docGen, docHash := db.ParseRevID(docRev) - cmpGen, cmpHash := db.ParseRevID(cmpRev) + docGen, docHash := db.ParseRevID(base.TestCtx(t), docRev) + cmpGen, cmpHash := db.ParseRevID(base.TestCtx(t), cmpRev) if docGen == cmpGen { if docHash != cmpHash { return false, fmt.Errorf("rev generations match but hashes are different: %v, %v", docRev, cmpRev), nil @@ -6665,7 +6663,7 @@ func TestDefaultConflictResolverWithTombstoneLocal(t *testing.T) { ctx1 := rt1.Context() defaultConflictResolver, err := db.NewCustomConflictResolver( - `function(conflict) { return defaultPolicy(conflict); }`, rt1.GetDatabase().Options.JavascriptTimeout) + ctx1, `function(conflict) { return defaultPolicy(conflict); }`, rt1.GetDatabase().Options.JavascriptTimeout) require.NoError(t, err, "Error creating custom conflict resolver") sgwStats, err := base.SyncGatewayStats.NewDBStats(t.Name(), false, false, false, nil, nil) require.NoError(t, err) @@ -6695,11 +6693,12 @@ func TestDefaultConflictResolverWithTombstoneLocal(t *testing.T) { require.NoError(t, ar.Start(ctx1), "Error starting replication") defer func() { require.NoError(t, ar.Stop(), "Error stopping replication") }() + ctx := base.TestCtx(t) // Wait for the original document revision written to rt1 to arrive at rt2. rt2RevIDCreated := rt1RevIDCreated require.NoError(t, rt2.WaitForCondition(func() bool { doc, _ := rt2.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) - return doc != nil && len(doc.Body()) > 0 + return doc != nil && len(doc.Body(ctx)) > 0 })) requireRevID(t, rt2, docID, rt2RevIDCreated) @@ -6820,7 +6819,7 @@ func TestDefaultConflictResolverWithTombstoneRemote(t *testing.T) { ctx1 := rt1.Context() defaultConflictResolver, err := db.NewCustomConflictResolver( - `function(conflict) { return defaultPolicy(conflict); }`, rt1.GetDatabase().Options.JavascriptTimeout) + ctx1, `function(conflict) { return defaultPolicy(conflict); }`, rt1.GetDatabase().Options.JavascriptTimeout) require.NoError(t, err, "Error creating custom conflict resolver") sgwStats, err := base.SyncGatewayStats.NewDBStats(t.Name(), false, false, false, nil, nil) require.NoError(t, err) @@ -6850,11 +6849,12 @@ func TestDefaultConflictResolverWithTombstoneRemote(t *testing.T) { require.NoError(t, ar.Start(ctx1), "Error starting replication") defer func() { require.NoError(t, ar.Stop(), "Error stopping replication") }() + ctx := base.TestCtx(t) // Wait for the original document revision written to rt2 to arrive at rt1. rt1RevIDCreated := rt2RevIDCreated require.NoError(t, rt1.WaitForCondition(func() bool { doc, _ := rt1.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) - return doc != nil && len(doc.Body()) > 0 + return doc != nil && len(doc.Body(ctx)) > 0 })) requireRevID(t, rt1, docID, rt1RevIDCreated) @@ -7109,7 +7109,7 @@ func TestLocalWinsConflictResolution(t *testing.T) { remoteRevID := remoteDoc.ExtractRev() assert.Equal(t, localRevID, remoteRevID) // local and remote rev IDs must match - localGeneration, _ := db.ParseRevID(localRevID) + localGeneration, _ := db.ParseRevID(activeRT.Context(), localRevID) assert.Equal(t, test.expectedResult.generation, localGeneration) // validate expected generation assert.Equal(t, test.expectedResult.propertyValue, remoteDoc["prop"].(string)) // validate expected body assert.Equal(t, test.expectedResult.propertyValue, localDoc["prop"].(string)) // validate expected body @@ -7191,8 +7191,8 @@ func TestSendChangesToNoConflictPreHydrogenTarget(t *testing.T) { }) assert.NoError(t, err) - assert.Equal(t, db.ReplicationStateStopped, ar.GetStatus().Status) - assert.Equal(t, db.PreHydrogenTargetAllowConflictsError.Error(), ar.GetStatus().ErrorMessage) + assert.Equal(t, db.ReplicationStateStopped, ar.GetStatus(ctx1).Status) + assert.Equal(t, db.PreHydrogenTargetAllowConflictsError.Error(), ar.GetStatus(ctx1).ErrorMessage) } func TestReplicatorConflictAttachment(t *testing.T) { base.RequireNumTestBuckets(t, 2) @@ -7318,7 +7318,7 @@ func TestConflictResolveMergeWithMutatedRev(t *testing.T) { passiveDBURL, err := url.Parse(srv.URL + "/db") require.NoError(t, err) - customConflictResolver, err := db.NewCustomConflictResolver(`function(conflict){ + customConflictResolver, err := db.NewCustomConflictResolver(ctx1, `function(conflict){ var mutatedLocal = conflict.LocalDocument; mutatedLocal.source = "merged"; mutatedLocal["_deleted"] = true; @@ -7357,13 +7357,11 @@ func TestConflictResolveMergeWithMutatedRev(t *testing.T) { require.NoError(t, ar.Start(ctx1)) - val, found := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { dbRepStats, err := base.SyncGatewayStats.DbStats[t.Name()].DBReplicatorStats(ar.ID) require.NoError(t, err) return dbRepStats.PulledCount.Value() }, 1) - assert.True(t, found) - assert.Equal(t, int64(1), val) rt1.WaitForReplicationStatus(t.Name(), db.ReplicationStateStopped) } @@ -7443,7 +7441,7 @@ func TestReplicatorDoNotSendDeltaWhenSrcIsTombstone(t *testing.T) { CollectionsEnabled: !activeRT.GetDatabase().OnlyDefaultCollection(), }) require.NoError(t, err) - assert.Equal(t, "", ar.GetStatus().LastSeqPush) + assert.Equal(t, "", ar.GetStatus(activeCtx).LastSeqPush) assert.NoError(t, ar.Start(activeCtx)) // Wait for active to replicate to passive @@ -7553,7 +7551,7 @@ func TestUnprocessableDeltas(t *testing.T) { CollectionsEnabled: !activeRT.GetDatabase().OnlyDefaultCollection(), }) require.NoError(t, err) - assert.Equal(t, "", ar.GetStatus().LastSeqPush) + assert.Equal(t, "", ar.GetStatus(activeCtx).LastSeqPush) assert.NoError(t, ar.Start(activeCtx)) @@ -7664,12 +7662,12 @@ func TestReplicatorIgnoreRemovalBodies(t *testing.T) { CollectionsEnabled: !activeRT.GetDatabase().OnlyDefaultCollection(), }) require.NoError(t, err) - docWriteFailuresBefore := ar.GetStatus().DocWriteFailures + docWriteFailuresBefore := ar.GetStatus(activeCtx).DocWriteFailures assert.NoError(t, ar.Start(activeCtx)) activeRT.WaitForReplicationStatus(ar.ID, db.ReplicationStateStopped) - assert.Equal(t, docWriteFailuresBefore, ar.GetStatus().DocWriteFailures, "ISGR should ignore _remove:true bodies when purgeOnRemoval is disabled. CBG-1428 regression.") + assert.Equal(t, docWriteFailuresBefore, ar.GetStatus(activeCtx).DocWriteFailures, "ISGR should ignore _remove:true bodies when purgeOnRemoval is disabled. CBG-1428 regression.") } // CBG-1995: Test the support for using an underscore prefix in the top-level body of a document @@ -7837,11 +7835,9 @@ func TestActiveReplicatorBlipsync(t *testing.T) { assert.NoError(t, ar.Stop()) // Wait for active stat to drop to original value - numReplicationsActive, ok := base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.Database().NumReplicationsActive.Value() }, startNumReplicationsActive) - assert.True(t, ok) - assert.Equal(t, startNumReplicationsActive, numReplicationsActive) // Verify total stat has not been decremented numReplicationsTotal = rt.GetDatabase().DbStats.Database().NumReplicationsTotal.Value() @@ -7938,7 +7934,7 @@ func TestReplicatorDeprecatedCredentials(t *testing.T) { assert.Equal(t, "", config.RemoteUsername) assert.Equal(t, "", config.RemotePassword) - _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(t.Name(), "stop") + _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(activeRT.Context(), t.Name(), "stop") require.NoError(t, err) activeRT.WaitForReplicationStatus(t.Name(), db.ReplicationStateStopped) err = activeRT.GetDatabase().SGReplicateMgr.DeleteReplication(t.Name()) @@ -8035,7 +8031,7 @@ func TestGroupIDReplications(t *testing.T) { go func() { serverErr <- rest.StartServer(ctx, &config, sc) }() - require.NoError(t, sc.WaitForRESTAPIs()) + require.NoError(t, sc.WaitForRESTAPIs(ctx)) dbConfig := rest.DbConfig{ AutoImport: true, @@ -8110,8 +8106,7 @@ func TestGroupIDReplications(t *testing.T) { require.NoError(t, err) dbstats, err := dbContext.DbStats.DBReplicatorStats("repl") require.NoError(t, err) - actualPushed, _ := base.WaitForStat(dbstats.NumDocPushed.Value, expectedPushed) - assert.Equal(t, expectedPushed, actualPushed) + base.RequireWaitForStat(t, dbstats.NumDocPushed.Value, expectedPushed) } } } @@ -8252,8 +8247,7 @@ function (doc) { require.NoError(t, err) activeRT.WaitForReplicationStatus(replName, db.ReplicationStateRunning) - value, _ := base.WaitForStat(receiverRT.GetDatabase().DbStats.Database().NumDocWrites.Value, 6) - assert.EqualValues(t, 6, value) + base.RequireWaitForStat(t, receiverRT.GetDatabase().DbStats.Database().NumDocWrites.Value, 6) changesResults, err := receiverRT.WaitForChanges(6, "/{{.keyspace}}/_changes?since=0&include_docs=true", "", true) assert.NoError(t, err) @@ -8266,7 +8260,7 @@ function (doc) { } // Stop and remove replicator (to stop checkpointing after teardown causing panic) - _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(replName, "stop") + _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(activeRT.Context(), replName, "stop") require.NoError(t, err) activeRT.WaitForReplicationStatus(replName, db.ReplicationStateStopped) err = activeRT.GetDatabase().SGReplicateMgr.DeleteReplication(replName) @@ -8288,11 +8282,10 @@ function (doc) { rest.RequireStatus(t, resp, http.StatusCreated) activeRT.WaitForReplicationStatus(replName, db.ReplicationStateRunning) - value, _ = base.WaitForStat(receiverRT.GetDatabase().DbStats.Database().NumDocWrites.Value, 10) - assert.EqualValues(t, 10, value) + base.RequireWaitForStat(t, receiverRT.GetDatabase().DbStats.Database().NumDocWrites.Value, 10) // Stop and remove replicator - _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(replName, "stop") + _, err = activeRT.GetDatabase().SGReplicateMgr.PutReplicationStatus(activeRT.Context(), replName, "stop") require.NoError(t, err) activeRT.WaitForReplicationStatus(replName, db.ReplicationStateStopped) err = activeRT.GetDatabase().SGReplicateMgr.DeleteReplication(replName) diff --git a/rest/replicatortest/replicator_test_helper.go b/rest/replicatortest/replicator_test_helper.go index 9507f96115..9f7121d815 100644 --- a/rest/replicatortest/replicator_test_helper.go +++ b/rest/replicatortest/replicator_test_helper.go @@ -90,7 +90,7 @@ func waitForTombstone(t *testing.T, rt *rest.RestTester, docID string) { require.NoError(t, rt.WaitForPendingChanges()) require.NoError(t, rt.WaitForCondition(func() bool { doc, _ := rt.GetSingleTestDatabaseCollection().GetDocument(base.TestCtx(t), docID, db.DocUnmarshalAll) - return doc.IsDeleted() && len(doc.Body()) == 0 + return doc.IsDeleted() && len(doc.Body(base.TestCtx(t))) == 0 })) } diff --git a/rest/rest_tester_cluster_test.go b/rest/rest_tester_cluster_test.go index a7c91dfaa3..df63313591 100644 --- a/rest/rest_tester_cluster_test.go +++ b/rest/rest_tester_cluster_test.go @@ -202,7 +202,7 @@ func TestPersistentDbConfigWithInvalidUpsert(t *testing.T) { assert.NotContains(t, string(resp.BodyBytes()), `"revs_limit":`) // remove the db config directly from the bucket - docID := PersistentConfigKey(*rtc.config.groupID, db) + docID := PersistentConfigKey(base.TestCtx(t), *rtc.config.groupID, db) // metadata store _, err = rtc.testBucket.DefaultDataStore().Remove(docID, 0) require.NoError(t, err) diff --git a/rest/revocation_test.go b/rest/revocation_test.go index 43378d25e2..9f58a09984 100644 --- a/rest/revocation_test.go +++ b/rest/revocation_test.go @@ -118,7 +118,8 @@ func (tester *ChannelRevocationTester) removeUserChannel(user, channel string) { } func (tester *ChannelRevocationTester) fillToSeq(seq uint64) { - currentSeq, err := tester.restTester.GetDatabase().LastSequence() + ctx := base.DatabaseLogCtx(base.TestCtx(tester.restTester.TB), tester.restTester.GetDatabase().Name, nil) + currentSeq, err := tester.restTester.GetDatabase().LastSequence(ctx) require.NoError(tester.test, err) loopCount := seq - currentSeq @@ -2504,7 +2505,7 @@ func TestBlipRevokeNonExistentRole(t *testing.T) { require.NoError(t, bt.StartPull()) // in the failing case we'll panic before hitting this - base.WaitForStat(func() int64 { + base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.CBLReplicationPull().NumPullReplCaughtUp.Value() }, 1) } diff --git a/rest/server_context.go b/rest/server_context.go index a4f55d3a8f..52896e4963 100644 --- a/rest/server_context.go +++ b/rest/server_context.go @@ -172,20 +172,20 @@ func NewServerContext(ctx context.Context, config *StartupConfig, persistentConf return sc } -func (sc *ServerContext) WaitForRESTAPIs() error { +func (sc *ServerContext) WaitForRESTAPIs(ctx context.Context) error { timeout := 30 * time.Second interval := time.Millisecond * 100 numAttempts := int(timeout / interval) - timeoutCtx, cancelFn := context.WithTimeout(context.Background(), timeout) + ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() - err, _ := base.RetryLoopCtx("Wait for REST APIs", func() (shouldRetry bool, err error, value interface{}) { + err, _ := base.RetryLoop(ctx, "Wait for REST APIs", func() (shouldRetry bool, err error, value interface{}) { sc.lock.RLock() defer sc.lock.RUnlock() if len(sc._httpServers) == 3 { return false, nil, nil } return true, nil, nil - }, base.CreateSleeperFunc(numAttempts, int(interval.Milliseconds())), timeoutCtx) + }, base.CreateSleeperFunc(numAttempts, int(interval.Milliseconds()))) return err } @@ -224,7 +224,7 @@ func (sc *ServerContext) Close(ctx context.Context) { for _, db := range sc.databases_ { db.Close(ctx) - _ = db.EventMgr.RaiseDBStateChangeEvent(db.Name, "offline", "Database context closed", &sc.Config.API.AdminInterface) + _ = db.EventMgr.RaiseDBStateChangeEvent(ctx, db.Name, "offline", "Database context closed", &sc.Config.API.AdminInterface) } sc.databases_ = nil sc.invalidDatabaseConfigTracking.dbNames = nil @@ -372,7 +372,7 @@ func (sc *ServerContext) PostUpgrade(ctx context.Context, preview bool) (postUpg for name, database := range sc.databases_ { // View cleanup - removedDDocs, _ := database.RemoveObsoleteDesignDocs(preview) + removedDDocs, _ := database.RemoveObsoleteDesignDocs(ctx, preview) // Index cleanup var removedIndexes []string @@ -592,6 +592,7 @@ func (sc *ServerContext) _getOrAddDatabaseFromConfig(ctx context.Context, config } err, _ = base.RetryLoop( + ctx, fmt.Sprintf("waiting for %s.%s.%s to exist", base.MD(bucket.GetName()), base.MD(scopeName), base.MD(collectionName)), waitForCollection, base.CreateMaxDoublingSleeperFunc(30, 10, 1000)) @@ -600,7 +601,7 @@ func (sc *ServerContext) _getOrAddDatabaseFromConfig(ctx context.Context, config return nil, fmt.Errorf("error attempting to create/update database: %w", err) } // Check if scope/collection specified exists. Will enter retry loop if connection unsuccessful - if err := base.WaitUntilDataStoreExists(dataStore); err != nil { + if err := base.WaitUntilDataStoreExists(ctx, dataStore); err != nil { return nil, fmt.Errorf("attempting to create/update database with a scope/collection that is not found") } metadataIndexOption := db.IndexesWithoutMetadata @@ -738,7 +739,7 @@ func (sc *ServerContext) _getOrAddDatabaseFromConfig(ctx context.Context, config config.Unsupported.WarningThresholds.ChannelNameSize = &base.DefaultWarnThresholdChannelNameSize } - autoImport, err := config.AutoImportEnabled() + autoImport, err := config.AutoImportEnabled(ctx) if err != nil { return nil, err } @@ -888,7 +889,7 @@ func (sc *ServerContext) _getOrAddDatabaseFromConfig(ctx context.Context, config } atomic.StoreUint32(&dbcontext.State, db.DBOffline) - _ = dbcontext.EventMgr.RaiseDBStateChangeEvent(dbName, "offline", stateChangeMsg, &sc.Config.API.AdminInterface) + _ = dbcontext.EventMgr.RaiseDBStateChangeEvent(ctx, dbName, "offline", stateChangeMsg, &sc.Config.API.AdminInterface) return dbcontext, nil } @@ -907,7 +908,7 @@ func (sc *ServerContext) _getOrAddDatabaseFromConfig(ctx context.Context, config return nil, err } atomic.StoreUint32(&dbcontext.State, db.DBOnline) - _ = dbcontext.EventMgr.RaiseDBStateChangeEvent(dbName, "online", stateChangeMsg, &sc.Config.API.AdminInterface) + _ = dbcontext.EventMgr.RaiseDBStateChangeEvent(ctx, dbName, "online", stateChangeMsg, &sc.Config.API.AdminInterface) return dbcontext, nil } else { // If asyncOnline is requested, set state to Starting and spawn a separate goroutine to wait for init completion @@ -957,7 +958,7 @@ func (sc *ServerContext) asyncDatabaseOnline(nonCancelCtx base.NonCancellableCon } stateChangeMsg := "DB loaded from config" - _ = dbc.EventMgr.RaiseDBStateChangeEvent(dbc.Name, "online", stateChangeMsg, &sc.Config.API.AdminInterface) + _ = dbc.EventMgr.RaiseDBStateChangeEvent(ctx, dbc.Name, "online", stateChangeMsg, &sc.Config.API.AdminInterface) } func (sc *ServerContext) GetDbVersion(dbName string) string { @@ -1223,30 +1224,24 @@ func dbcOptionsFromConfig(ctx context.Context, sc *ServerContext, config *DbConf } // Per-database console logging config overrides - if config.Logging != nil && config.Logging.Console != nil { - logKey := base.ToLogKey(config.Logging.Console.LogKeys) - contextOptions.LoggingConfig.Console = &base.DbConsoleLogConfig{ - LogLevel: config.Logging.Console.LogLevel, - LogKeys: &logKey, - } - } + contextOptions.LoggingConfig.Console = config.toDbConsoleLogConfig(ctx) if sc.Config.Unsupported.UserQueries != nil && *sc.Config.Unsupported.UserQueries { var err error if config.UserFunctions != nil { - contextOptions.UserFunctions, err = functions.CompileFunctions(*config.UserFunctions) + contextOptions.UserFunctions, err = functions.CompileFunctions(ctx, *config.UserFunctions) if err != nil { return contextOptions, err } } if config.GraphQL != nil { - contextOptions.GraphQL, err = functions.CompileGraphQL(config.GraphQL) + contextOptions.GraphQL, err = functions.CompileGraphQL(ctx, config.GraphQL) if err != nil { return contextOptions, err } } } else if config.UserFunctions != nil || config.GraphQL != nil { - base.WarnfCtx(context.TODO(), `Database config options "functions" and "graphql" ignored because unsupported.user_queries feature flag is not enabled`) + base.WarnfCtx(ctx, `Database config options "functions" and "graphql" ignored because unsupported.user_queries feature flag is not enabled`) } return contextOptions, nil @@ -1284,7 +1279,7 @@ func (sc *ServerContext) TakeDbOnline(nonContextStruct base.NonCancellableContex // validateMetadataStore will func validateMetadataStore(ctx context.Context, metadataStore base.DataStore) error { // Check if scope/collection specified exists. Will enter retry loop if connection unsuccessful - err := base.WaitUntilDataStoreExists(metadataStore) + err := base.WaitUntilDataStoreExists(ctx, metadataStore) if err == nil { return nil } @@ -1356,7 +1351,7 @@ func (sc *ServerContext) initEventHandlers(ctx context.Context, dbcontext *db.Da if config.Unsupported != nil { insecureSkipVerify = config.Unsupported.RemoteConfigTlsSkipVerify } - filter, err := loadJavaScript(conf.Filter, insecureSkipVerify) + filter, err := loadJavaScript(ctx, conf.Filter, insecureSkipVerify) if err != nil { return &JavaScriptLoadError{ JSLoadType: WebhookFilter, @@ -1385,7 +1380,7 @@ func (sc *ServerContext) initEventHandlers(ctx context.Context, dbcontext *db.Da base.WarnfCtx(ctx, "Error parsing wait_for_process from config, using default %s", err) } } - dbcontext.EventMgr.Start(config.EventHandlers.MaxEventProc, int(customWaitTime)) + dbcontext.EventMgr.Start(ctx, config.EventHandlers.MaxEventProc, int(customWaitTime)) return nil } @@ -1409,12 +1404,12 @@ func (sc *ServerContext) processEventHandlersForEvent(ctx context.Context, event for _, event := range events { switch event.HandlerType { case "webhook": - wh, err := db.NewWebhook(event.Url, event.Filter, event.Timeout, event.Options) + wh, err := db.NewWebhook(ctx, event.Url, event.Filter, event.Timeout, event.Options) if err != nil { base.WarnfCtx(ctx, "Error creating webhook %v", err) return err } - dbcontext.EventMgr.RegisterEventHandler(wh, eventType) + dbcontext.EventMgr.RegisterEventHandler(ctx, wh, eventType) default: return errors.New(fmt.Sprintf("Unknown event handler type %s", event.HandlerType)) } @@ -1488,7 +1483,7 @@ func (sc *ServerContext) _suspendDatabase(ctx context.Context, dbName string) er } bucket := dbCtx.Bucket.GetName() - base.InfofCtx(context.TODO(), base.KeyAll, "Suspending db %q (bucket %q)", base.MD(dbName), base.MD(bucket)) + base.InfofCtx(ctx, base.KeyAll, "Suspending db %q (bucket %q)", base.MD(dbName), base.MD(bucket)) if !sc._unloadDatabase(ctx, dbName) { return base.ErrNotFound @@ -1525,7 +1520,7 @@ func (sc *ServerContext) _unsuspendDatabase(ctx context.Context, dbName string) bucket = *dbConfig.Bucket } - cas, err := sc.BootstrapContext.GetConfig(bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, &dbConfig.DatabaseConfig) + cas, err := sc.BootstrapContext.GetConfig(ctx, bucket, sc.Config.Bootstrap.ConfigGroupID, dbName, &dbConfig.DatabaseConfig) if err == base.ErrNotFound { // Database no longer exists, so clean up dbConfigs base.InfofCtx(ctx, base.KeyConfig, "Database %q has been removed while suspended from bucket %q", base.MD(dbName), base.MD(bucket)) @@ -1599,7 +1594,7 @@ func (sc *ServerContext) logStats(ctx context.Context) error { base.WarnfCtx(ctx, "Error getting sigar based system resource stats: %v", err) } - sc.updateCalculatedStats() + sc.updateCalculatedStats(ctx) // Create wrapper expvar map in order to add a timestamp field for logging purposes currentTime := time.Now() wrapper := statsWrapper{ @@ -1633,13 +1628,13 @@ func (sc *ServerContext) logNetworkInterfaceStats(ctx context.Context) { } // Updates stats that are more efficient to calculate at stats collection time -func (sc *ServerContext) updateCalculatedStats() { +func (sc *ServerContext) updateCalculatedStats(ctx context.Context) { sc.lock.RLock() defer sc.lock.RUnlock() for _, dbContext := range sc.databases_ { dbState := atomic.LoadUint32(&dbContext.State) if dbState == db.DBOnline { - dbContext.UpdateCalculatedStats() + dbContext.UpdateCalculatedStats(ctx) } } @@ -1651,7 +1646,7 @@ func initClusterAgent(ctx context.Context, clusterAddress, clusterUser, clusterP return nil, err } - tlsRootCAProvider, err := base.GoCBCoreTLSRootCAProvider(tlsSkipVerify, caCertPath) + tlsRootCAProvider, err := base.GoCBCoreTLSRootCAProvider(ctx, tlsSkipVerify, caCertPath) if err != nil { return nil, err } @@ -1719,7 +1714,7 @@ func initClusterAgent(ctx context.Context, clusterAddress, clusterUser, clusterP // initializeGoCBAgent Obtains a gocb agent from the current server connection. Requires the agent to be closed after use. // Uses retry loop func (sc *ServerContext) initializeGoCBAgent(ctx context.Context) (*gocbcore.Agent, error) { - err, a := base.RetryLoop("Initialize Cluster Agent", func() (shouldRetry bool, err error, value interface{}) { + err, a := base.RetryLoop(ctx, "Initialize Cluster Agent", func() (shouldRetry bool, err error, value interface{}) { agent, err := initClusterAgent( ctx, sc.Config.Bootstrap.Server, sc.Config.Bootstrap.Username, sc.Config.Bootstrap.Password, @@ -1747,7 +1742,7 @@ func (sc *ServerContext) initializeGoCBAgent(ctx context.Context) (*gocbcore.Age // without any x509 keypair included in the tls config. This client can be used to perform basic // authentication checks against the server. // Client creation otherwise clones the approach used by gocb. -func (sc *ServerContext) initializeNoX509HttpClient() (*http.Client, error) { +func (sc *ServerContext) initializeNoX509HttpClient(ctx context.Context) (*http.Client, error) { // baseTlsConfig defines the tlsConfig except for ServerName, which is updated based // on addr in DialTLS @@ -1755,7 +1750,7 @@ func (sc *ServerContext) initializeNoX509HttpClient() (*http.Client, error) { MinVersion: tls.VersionTLS12, } var rootCAs *x509.CertPool - tlsRootCAProvider, err := base.GoCBCoreTLSRootCAProvider(sc.Config.Bootstrap.ServerTLSSkipVerify, sc.Config.Bootstrap.CACertPath) + tlsRootCAProvider, err := base.GoCBCoreTLSRootCAProvider(ctx, sc.Config.Bootstrap.ServerTLSSkipVerify, sc.Config.Bootstrap.CACertPath) if err != nil { return nil, err } @@ -1839,10 +1834,10 @@ func (sc *ServerContext) ObtainManagementEndpointsAndHTTPClient() ([]string, *ht // For Authorization it checks whether the user has any ONE of the supplied accessPermissions // If the user is authorized it will also check the responsePermissions and return the results for these. These can be // used by handlers to determine different responses based on the permissions the user has. -func CheckPermissions(httpClient *http.Client, managementEndpoints []string, bucketName, username, password string, accessPermissions []Permission, responsePermissions []Permission) (statusCode int, permissionResults map[string]bool, err error) { +func CheckPermissions(ctx context.Context, httpClient *http.Client, managementEndpoints []string, bucketName, username, password string, accessPermissions []Permission, responsePermissions []Permission) (statusCode int, permissionResults map[string]bool, err error) { combinedPermissions := append(accessPermissions, responsePermissions...) body := []byte(strings.Join(FormatPermissionNames(combinedPermissions, bucketName), ",")) - statusCode, bodyResponse, err := doHTTPAuthRequest(httpClient, username, password, "POST", "/pools/default/checkPermissions", managementEndpoints, body) + statusCode, bodyResponse, err := doHTTPAuthRequest(ctx, httpClient, username, password, "POST", "/pools/default/checkPermissions", managementEndpoints, body) if err != nil { return http.StatusInternalServerError, nil, err } @@ -1888,8 +1883,8 @@ func CheckPermissions(httpClient *http.Client, managementEndpoints []string, buc return http.StatusForbidden, nil, nil } -func CheckRoles(httpClient *http.Client, managementEndpoints []string, username, password string, requestedRoles []RouteRole, bucketName string) (statusCode int, err error) { - statusCode, bodyResponse, err := doHTTPAuthRequest(httpClient, username, password, "GET", "/whoami", managementEndpoints, nil) +func CheckRoles(ctx context.Context, httpClient *http.Client, managementEndpoints []string, username, password string, requestedRoles []RouteRole, bucketName string) (statusCode int, err error) { + statusCode, bodyResponse, err := doHTTPAuthRequest(ctx, httpClient, username, password, "GET", "/whoami", managementEndpoints, nil) if err != nil { return http.StatusInternalServerError, err } @@ -1928,7 +1923,7 @@ func CheckRoles(httpClient *http.Client, managementEndpoints []string, username, return http.StatusForbidden, nil } -func doHTTPAuthRequest(httpClient *http.Client, username, password, method, path string, endpoints []string, requestBody []byte) (statusCode int, responseBody []byte, err error) { +func doHTTPAuthRequest(ctx context.Context, httpClient *http.Client, username, password, method, path string, endpoints []string, requestBody []byte) (statusCode int, responseBody []byte, err error) { retryCount := 0 worker := func() (shouldRetry bool, err error, value interface{}) { @@ -1955,7 +1950,7 @@ func doHTTPAuthRequest(httpClient *http.Client, username, password, method, path return false, err, nil } - err, result := base.RetryLoop("", worker, base.CreateSleeperFunc(10, 100)) + err, result := base.RetryLoop(ctx, "", worker, base.CreateSleeperFunc(10, 100)) if err != nil { return 0, nil, err } @@ -1997,14 +1992,14 @@ func (sc *ServerContext) initializeCouchbaseServerConnections(ctx context.Contex sc.GoCBAgent = goCBAgent //sc.DatabaseInitManager.cluster = goCBAgent. - sc.NoX509HTTPClient, err = sc.initializeNoX509HttpClient() + sc.NoX509HTTPClient, err = sc.initializeNoX509HttpClient(ctx) if err != nil { return err } // Fetch database configs from bucket and start polling for new buckets and config updates. if sc.persistentConfig { - couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(sc.Config, base.CachedClusterConnections) + couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(ctx, sc.Config, base.CachedClusterConnections) if err != nil { return err } diff --git a/rest/server_context_test.go b/rest/server_context_test.go index b189d1d5cc..f5fa2c7fd2 100644 --- a/rest/server_context_test.go +++ b/rest/server_context_test.go @@ -321,7 +321,7 @@ func TestObtainManagementEndpointsFromServerContextWithX509(t *testing.T) { require.NoError(t, err) svrctx.GoCBAgent = goCBAgent - noX509HttpClient, err := svrctx.initializeNoX509HttpClient() + noX509HttpClient, err := svrctx.initializeNoX509HttpClient(ctx) require.NoError(t, err) svrctx.NoX509HTTPClient = noX509HttpClient @@ -383,7 +383,7 @@ func TestStartAndStopHTTPServers(t *testing.T) { require.NoError(t, <-serveErr) }() - err, _ = base.RetryLoop("try http request", func() (shouldRetry bool, err error, value interface{}) { + err, _ = base.RetryLoop(ctx, "try http request", func() (shouldRetry bool, err error, value interface{}) { resp, err := http.Get("http://" + config.API.PublicInterface) if err != nil { return true, err, nil @@ -445,7 +445,7 @@ func TestTLSSkipVerifyCombinations(t *testing.T) { }, } - err := startupConfig.Validate(base.IsEnterpriseEdition()) + err := startupConfig.Validate(base.TestCtx(t), base.IsEnterpriseEdition()) if test.expectError { assert.Error(t, err) assert.Contains(t, err.Error(), errorText) @@ -569,7 +569,7 @@ func TestUseTLSServer(t *testing.T) { t.Run(test.name, func(t *testing.T) { sc := StartupConfig{Bootstrap: BootstrapConfig{Server: test.server, UseTLSServer: &test.useTLSServer}} - err := sc.Validate(base.IsEnterpriseEdition()) + err := sc.Validate(base.TestCtx(t), base.IsEnterpriseEdition()) if test.expectedError != nil { require.Error(t, err) @@ -752,7 +752,7 @@ func TestLogFlush(t *testing.T) { } sleeper := base.CreateSleeperFunc(200, 100) - err, _ = base.RetryLoop("Wait for log files", worker, sleeper) + err, _ = base.RetryLoop(ctx, "Wait for log files", worker, sleeper) assert.NoError(t, err) if !assert.Len(t, files, testCase.ExpectedLogFileCount) { // Try to figure who is writing to the files diff --git a/rest/serverless_test.go b/rest/serverless_test.go index 23ab48c109..dbaaa32629 100644 --- a/rest/serverless_test.go +++ b/rest/serverless_test.go @@ -435,7 +435,7 @@ func TestServerlessFetchConfigsLimited(t *testing.T) { require.NoError(t, err) // Update database config in the bucket (increment version) newCas, err := sc.BootstrapContext.UpdateConfig(ctx, tb.GetName(), sc.Config.Bootstrap.ConfigGroupID, "db", func(bucketDbConfig *DatabaseConfig) (updatedConfig *DatabaseConfig, err error) { - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(rt.Context(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } @@ -464,7 +464,7 @@ func TestServerlessFetchConfigsLimited(t *testing.T) { // Update database config in the bucket again to test caching disable case newCas, err = sc.BootstrapContext.UpdateConfig(ctx, tb.GetName(), sc.Config.Bootstrap.ConfigGroupID, "db", func(bucketDbConfig *DatabaseConfig) (updatedConfig *DatabaseConfig, err error) { - bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(bucketDbConfig.Version, &bucketDbConfig.DbConfig) + bucketDbConfig.Version, err = GenerateDatabaseConfigVersionID(rt.Context(), bucketDbConfig.Version, &bucketDbConfig.DbConfig) if err != nil { return nil, err } diff --git a/rest/session_test.go b/rest/session_test.go index d897d7a366..e5649b393f 100644 --- a/rest/session_test.go +++ b/rest/session_test.go @@ -221,7 +221,7 @@ func TestLogin(t *testing.T) { rt := NewRestTester(t, nil) defer rt.Close() - a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) user, err := a.GetUser("") assert.NoError(t, err) user.SetDisabled(true) @@ -264,7 +264,7 @@ func TestCustomCookieName(t *testing.T) { defer rt.Close() // Disable guest user - a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) user, err := a.GetUser("") assert.NoError(t, err) user.SetDisabled(true) @@ -306,7 +306,7 @@ func TestSessionTtlGreaterThan30Days(t *testing.T) { rt := NewRestTester(t, nil) defer rt.Close() - a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + a := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) user, err := a.GetUser("") assert.NoError(t, err) user.SetDisabled(true) @@ -681,7 +681,7 @@ func TestSessionExpirationDateTimeFormat(t *testing.T) { rt := NewRestTester(t, nil) defer rt.Close() - authenticator := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) user, err := authenticator.NewUser("alice", "letMe!n", channels.BaseSetOf(t, "*")) assert.NoError(t, err, "Couldn't create new user") assert.NoError(t, authenticator.Save(user), "Couldn't save new user") diff --git a/rest/sgcollect.go b/rest/sgcollect.go index 33e9140910..13a4537f2a 100644 --- a/rest/sgcollect.go +++ b/rest/sgcollect.go @@ -313,7 +313,7 @@ func sgCollectPaths() (sgBinary, sgCollectBinary string, err error) { return "", "", err } - logCtx := context.Background() + logCtx := context.TODO() // this is global variable at init, we can't pass it in easily hasBinDir := true sgCollectPath := filepath.Join("tools", "sgcollect_info") diff --git a/rest/sync_fn_test.go b/rest/sync_fn_test.go index ce0544ca93..cae830962e 100644 --- a/rest/sync_fn_test.go +++ b/rest/sync_fn_test.go @@ -276,7 +276,7 @@ func TestSyncFnDocBodyPropertiesSwitchActiveTombstone(t *testing.T) { rev3aID := RespRevID(t, resp) // rev 2-b - _, rev1Hash := db.ParseRevID(rev1ID) + _, rev1Hash := db.ParseRevID(rt.Context(), rev1ID) resp = rt.SendAdminRequest("PUT", fmt.Sprintf("/{{.keyspace}}/%s?new_edits=false", testDocID), `{"`+db.BodyRevisions+`":{"start":2,"ids":["b", "`+rev1Hash+`"]}}`) RequireStatus(t, resp, 201) rev2bID := RespRevID(t, resp) diff --git a/rest/upgradetest/main_test.go b/rest/upgradetest/main_test.go index edcd30ceac..2c82d51413 100644 --- a/rest/upgradetest/main_test.go +++ b/rest/upgradetest/main_test.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package upgradetest import ( + "context" "testing" "github.com/couchbase/sync_gateway/base" @@ -18,9 +19,10 @@ import ( ) func TestMain(m *testing.M) { + ctx := context.Background() // start of test process tbpOptions := base.TestBucketPoolOptions{ MemWatermarkThresholdMB: 8192, UseDefaultScope: true, } - db.TestBucketPoolWithIndexes(m, tbpOptions) + db.TestBucketPoolWithIndexes(ctx, m, tbpOptions) } diff --git a/rest/user_api_test.go b/rest/user_api_test.go index 3a0594f2d1..17df7a0f4a 100644 --- a/rest/user_api_test.go +++ b/rest/user_api_test.go @@ -488,7 +488,7 @@ func TestUserAndRoleResponseContentType(t *testing.T) { assert.Equal(t, "application/json", response.Header().Get("Content-Type")) // Create a new user and save to database to create user session. - authenticator := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions()) + authenticator := auth.NewAuthenticator(rt.MetadataStore(), nil, rt.GetDatabase().AuthenticatorOptions(rt.Context())) user, err := authenticator.NewUser("eve", "cGFzc3dvcmQ=", channels.BaseSetOf(t, "*")) assert.NoError(t, err, "Couldn't create new user") assert.NoError(t, authenticator.Save(user), "Couldn't save new user") diff --git a/rest/utilities_testing.go b/rest/utilities_testing.go index 43a1b2ded7..350343b9c9 100644 --- a/rest/utilities_testing.go +++ b/rest/utilities_testing.go @@ -249,7 +249,7 @@ func (rt *RestTester) Bucket() base.Bucket { sc.Unsupported.UserQueries = base.BoolPtr(rt.EnableUserQueries) // Allow EE-only config even in CE for testing using group IDs. - if err := sc.Validate(true); err != nil { + if err := sc.Validate(base.TestCtx(rt.TB), true); err != nil { panic("invalid RestTester StartupConfig: " + err.Error()) } @@ -325,7 +325,7 @@ func (rt *RestTester) Bucket() base.Bucket { rt.DatabaseConfig.SGReplicateEnabled = base.BoolPtr(rt.RestTesterConfig.SgReplicateEnabled) - autoImport, _ := rt.DatabaseConfig.AutoImportEnabled() + autoImport, _ := rt.DatabaseConfig.AutoImportEnabled(ctx) if rt.DatabaseConfig.ImportPartitions == nil && base.TestUseXattrs() && base.IsEnterpriseEdition() && autoImport { // Speed up test setup - most tests don't need more than one partition given we only have one node rt.DatabaseConfig.ImportPartitions = base.Uint16Ptr(1) @@ -773,7 +773,7 @@ func (rt *RestTester) WaitForChanges(numChangesExpected int, changesURL, usernam sleeper := base.CreateSleeperFunc(200, 100) - err, changesVal := base.RetryLoop("Wait for changes", waitForChangesWorker, sleeper) + err, changesVal := base.RetryLoop(rt.Context(), "Wait for changes", waitForChangesWorker, sleeper) if err != nil { return changes, err } @@ -804,7 +804,7 @@ func (rt *RestTester) WaitForConditionWithOptions(successFunc func() bool, maxNu } sleeper := base.CreateSleeperFunc(maxNumAttempts, timeToSleepMs) - err, _ := base.RetryLoop("Wait for condition options", waitForSuccess, sleeper) + err, _ := base.RetryLoop(rt.Context(), "Wait for condition options", waitForSuccess, sleeper) if err != nil { return err } @@ -814,7 +814,7 @@ func (rt *RestTester) WaitForConditionWithOptions(successFunc func() bool, maxNu func (rt *RestTester) WaitForConditionShouldRetry(conditionFunc func() (shouldRetry bool, err error, value interface{}), maxNumAttempts, timeToSleepMs int) error { sleeper := base.CreateSleeperFunc(maxNumAttempts, timeToSleepMs) - err, _ := base.RetryLoop("Wait for condition options", conditionFunc, sleeper) + err, _ := base.RetryLoop(rt.Context(), "Wait for condition options", conditionFunc, sleeper) if err != nil { return err } @@ -885,7 +885,7 @@ func (rt *RestTester) WaitForNViewResults(numResultsExpected int, viewUrlPath st description := fmt.Sprintf("Wait for %d view results for query to %v", numResultsExpected, viewUrlPath) sleeper := base.CreateSleeperFunc(200, 100) - err, returnVal := base.RetryLoop(description, worker, sleeper) + err, returnVal := base.RetryLoop(rt.Context(), description, worker, sleeper) if err != nil { return sgbucket.ViewResult{}, err @@ -918,7 +918,7 @@ func (rt *RestTester) WaitForViewAvailable(viewURLPath string) (err error) { description := "Wait for view readiness" sleeper := base.CreateSleeperFunc(200, 100) - err, _ = base.RetryLoop(description, worker, sleeper) + err, _ = base.RetryLoop(rt.Context(), description, worker, sleeper) return err @@ -984,13 +984,13 @@ func (rt *RestTester) SendAdminRequestWithHeaders(method, resource string, body func (rt *RestTester) PutDocumentWithRevID(docID string, newRevID string, parentRevID string, body db.Body) (response *TestResponse, err error) { requestBody := body.ShallowCopy() - newRevGeneration, newRevDigest := db.ParseRevID(newRevID) + newRevGeneration, newRevDigest := db.ParseRevID(base.TestCtx(rt.TB), newRevID) revisions := make(map[string]interface{}) revisions["start"] = newRevGeneration ids := []string{newRevDigest} if parentRevID != "" { - _, parentDigest := db.ParseRevID(parentRevID) + _, parentDigest := db.ParseRevID(base.TestCtx(rt.TB), parentRevID) ids = append(ids, parentDigest) } revisions["ids"] = ids @@ -1069,7 +1069,7 @@ func (rt *RestTester) GetDocumentSequence(key string) (sequence uint64) { func (rt *RestTester) ReplacePerBucketCredentials(config base.PerBucketCredentialsConfig) { rt.ServerContext().Config.BucketCredentials = config // Update the CouchbaseCluster to include the new bucket credentials - couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(rt.ServerContext().Config, base.PerUseClusterConnections) + couchbaseCluster, err := CreateCouchbaseClusterFromStartupConfig(base.TestCtx(rt.TB), rt.ServerContext().Config, base.PerUseClusterConnections) require.NoError(rt.TB, err) rt.ServerContext().BootstrapContext.Connection = couchbaseCluster } @@ -1828,6 +1828,7 @@ func (bt *BlipTester) WaitForNumChanges(numChangesExpected int) (changes [][]int } _, rawChanges := base.RetryLoop( + bt.restTester.Context(), "WaitForNumChanges", retryWorker, base.CreateDoublingSleeperFunc(10, 10), @@ -1895,6 +1896,7 @@ func (bt *BlipTester) WaitForNumDocsViaChanges(numDocsExpected int) (docs map[st } _, allDocs := base.RetryLoop( + bt.restTester.Context(), "WaitForNumDocsViaChanges", retryWorker, base.CreateDoublingSleeperFunc(20, 10), @@ -2281,7 +2283,7 @@ func WaitAndAssertConditionTimeout(t *testing.T, timeout time.Duration, fn func( func WaitAndAssertBackgroundManagerState(t testing.TB, expected db.BackgroundProcessState, getStateFunc func(t testing.TB) db.BackgroundProcessState) bool { t.Helper() - err, actual := base.RetryLoop(t.Name()+"-WaitAndAssertBackgroundManagerState", func() (shouldRetry bool, err error, value interface{}) { + err, actual := base.RetryLoop(base.TestCtx(t), t.Name()+"-WaitAndAssertBackgroundManagerState", func() (shouldRetry bool, err error, value interface{}) { actual := getStateFunc(t) return expected != actual, nil, actual }, base.CreateMaxDoublingSleeperFunc(30, 100, 1000)) @@ -2290,7 +2292,7 @@ func WaitAndAssertBackgroundManagerState(t testing.TB, expected db.BackgroundPro func WaitAndAssertBackgroundManagerExpiredHeartbeat(t testing.TB, bm *db.BackgroundManager) bool { t.Helper() - err, b := base.RetryLoop(t.Name()+"-assertNoHeartbeatDoc", func() (shouldRetry bool, err error, value interface{}) { + err, b := base.RetryLoop(base.TestCtx(t), t.Name()+"-assertNoHeartbeatDoc", func() (shouldRetry bool, err error, value interface{}) { b, err := bm.GetHeartbeatDoc(t) return !base.IsDocNotFoundError(err), err, b }, base.CreateMaxDoublingSleeperFunc(30, 100, 1000)) diff --git a/rest/utilities_testing_resttester.go b/rest/utilities_testing_resttester.go index bdc09e005c..459ff2b4be 100644 --- a/rest/utilities_testing_resttester.go +++ b/rest/utilities_testing_resttester.go @@ -299,7 +299,7 @@ func (rt *RestTester) InsertDbConfigToBucket(config *DatabaseConfig, bucketName } func (rt *RestTester) PersistDbConfigToBucket(dbConfig DbConfig, bucketName string) { - version, err := GenerateDatabaseConfigVersionID("", &dbConfig) + version, err := GenerateDatabaseConfigVersionID(rt.Context(), "", &dbConfig) require.NoError(rt.TB, err) metadataID, metadataIDError := rt.ServerContext().BootstrapContext.ComputeMetadataIDForDbConfig(base.TestCtx(rt.TB), &dbConfig) diff --git a/rest/utilities_testing_user.go b/rest/utilities_testing_user.go index 4abc965bb9..f880ed1b0f 100644 --- a/rest/utilities_testing_user.go +++ b/rest/utilities_testing_user.go @@ -40,7 +40,7 @@ func MakeUser(t *testing.T, httpClient *http.Client, serverURL, username, passwo return false, err, resp } - err, resp := base.RetryLoop("Admin Auth testing MakeUser", retryWorker, base.CreateSleeperFunc(10, 100)) + err, resp := base.RetryLoop(base.TestCtx(t), "Admin Auth testing MakeUser", retryWorker, base.CreateSleeperFunc(10, 100)) require.NoError(t, err) if resp.(*http.Response).StatusCode != http.StatusOK { @@ -67,7 +67,7 @@ func DeleteUser(t *testing.T, httpClient *http.Client, serverURL, username strin return false, err, resp } - err, resp := base.RetryLoop("Admin Auth testing DeleteUser", retryWorker, base.CreateSleeperFunc(10, 100)) + err, resp := base.RetryLoop(base.TestCtx(t), "Admin Auth testing DeleteUser", retryWorker, base.CreateSleeperFunc(10, 100)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.(*http.Response).StatusCode) diff --git a/rest/view_api_test.go b/rest/view_api_test.go index 14c3c65a28..d9ce8b2ed8 100644 --- a/rest/view_api_test.go +++ b/rest/view_api_test.go @@ -598,7 +598,7 @@ func TestPostInstallCleanup(t *testing.T) { defer rt.Close() // Cleanup existing design docs - _, err := rt.GetDatabase().RemoveObsoleteDesignDocs(false) + _, err := rt.GetDatabase().RemoveObsoleteDesignDocs(base.TestCtx(t), false) require.NoError(t, err) bucket := rt.Bucket() diff --git a/rest/x509_test.go b/rest/x509_test.go index 2125a3af15..ddcd393268 100644 --- a/rest/x509_test.go +++ b/rest/x509_test.go @@ -158,13 +158,13 @@ func setupX509Tests(t *testing.T, useIPAddress bool) (testBucket *base.TestBucke usingDocker, dockerName := base.TestUseCouchbaseServerDockerName() if usingDocker { - err = loadCertsIntoCouchbaseServerDocker(*testURL, ca, nodePair, dockerName) + err = loadCertsIntoCouchbaseServerDocker(base.TestCtx(t), *testURL, ca, nodePair, dockerName) } else { isLocalX509, localUserName := base.TestX509LocalServer() if isLocalX509 { - err = loadCertsIntoLocalCouchbaseServer(*testURL, ca, nodePair, localUserName) + err = loadCertsIntoLocalCouchbaseServer(base.TestCtx(t), *testURL, ca, nodePair, localUserName) } else { - err = loadCertsIntoCouchbaseServer(*testURL, ca, nodePair) + err = loadCertsIntoCouchbaseServer(base.TestCtx(t), *testURL, ca, nodePair) } } require.NoError(t, err) diff --git a/rest/x509_utils_test.go b/rest/x509_utils_test.go index 37fe010be9..20c7f0003e 100644 --- a/rest/x509_utils_test.go +++ b/rest/x509_utils_test.go @@ -253,45 +253,43 @@ func x509SSHUsername() string { } // loadCertsIntoCouchbaseServer will upload the given certs into Couchbase Server (via SSH and the REST API) -func loadCertsIntoCouchbaseServer(couchbaseServerURL url.URL, ca *caPair, node *nodePair) error { +func loadCertsIntoCouchbaseServer(ctx context.Context, couchbaseServerURL url.URL, ca *caPair, node *nodePair) error { // Copy node cert and key via SSH sshRemoteHost := x509SSHUsername() + "@" + couchbaseServerURL.Hostname() err := sshCopyFileAsExecutable(node.PEMFilepath, sshRemoteHost, "/opt/couchbase/var/lib/couchbase/inbox") if err != nil { return err } - logCtx := context.Background() - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node chain.pem to integration test server") + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node chain.pem to integration test server") err = sshCopyFileAsExecutable(node.KeyFilePath, sshRemoteHost, "/opt/couchbase/var/lib/couchbase/inbox") if err != nil { return err } - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node pkey.key to integration test server") + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node pkey.key to integration test server") - return uploadCACertViaREST(couchbaseServerURL, ca) + return uploadCACertViaREST(ctx, couchbaseServerURL, ca) } // loadCertsIntoCouchbaseServer will upload the given certs into Couchbase Server (via SSH and the REST API) -func loadCertsIntoCouchbaseServerDocker(couchbaseServerURL url.URL, ca *caPair, node *nodePair, containerName string) error { +func loadCertsIntoCouchbaseServerDocker(ctx context.Context, couchbaseServerURL url.URL, ca *caPair, node *nodePair, containerName string) error { err := copyLocalFileIntoDocker(containerName, node.PEMFilepath, "/opt/couchbase/var/lib/couchbase/inbox") if err != nil { return err } - logCtx := context.Background() - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node chain.pem to integration test server") + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node chain.pem to integration test server") err = copyLocalFileIntoDocker(containerName, node.KeyFilePath, "/opt/couchbase/var/lib/couchbase/inbox") if err != nil { return err } - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node pkey.key to integration test server") + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node pkey.key to integration test server") - return uploadCACertViaREST(couchbaseServerURL, ca) + return uploadCACertViaREST(ctx, couchbaseServerURL, ca) } // loadCertsIntoLocalCouchbaseServer will upload the given certs into Couchbase Server (via SSH and the REST API) -func loadCertsIntoLocalCouchbaseServer(couchbaseServerURL url.URL, ca *caPair, node *nodePair, localMacOSUser string) error { +func loadCertsIntoLocalCouchbaseServer(ctx context.Context, couchbaseServerURL url.URL, ca *caPair, node *nodePair, localMacOSUser string) error { localMacOSCouchbaseServerInbox := "/Users/" + localMacOSUser + "/Library/Application Support/Couchbase/var/lib/couchbase/inbox" @@ -300,18 +298,17 @@ func loadCertsIntoLocalCouchbaseServer(couchbaseServerURL url.URL, ca *caPair, n if err != nil { return err } - logCtx := context.Background() - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node chain.pem to integration test server") + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node chain.pem to integration test server") err = copyLocalFile(node.KeyFilePath, localMacOSCouchbaseServerInbox) if err != nil { return err } - base.DebugfCtx(logCtx, base.KeyAll, "copied x509 node pkey.key to integration test server") - return uploadCACertViaREST(couchbaseServerURL, ca) + base.DebugfCtx(ctx, base.KeyAll, "copied x509 node pkey.key to integration test server") + return uploadCACertViaREST(ctx, couchbaseServerURL, ca) } -func uploadCACertViaREST(couchbaseServerURL url.URL, ca *caPair) error { +func uploadCACertViaREST(ctx context.Context, couchbaseServerURL url.URL, ca *caPair) error { restAPIURL := basicAuthRESTPIURLFromConnstrHost(couchbaseServerURL) // Upload the CA cert via the REST API @@ -327,8 +324,7 @@ func uploadCACertViaREST(couchbaseServerURL url.URL, ca *caPair) error { if resp.StatusCode != http.StatusOK { return fmt.Errorf("couldn't uploadClusterCA: expected %d status code but got %d: %s", http.StatusOK, resp.StatusCode, respBody) } - logCtx := context.Background() - base.DebugfCtx(logCtx, base.KeyAll, "uploaded ca.pem to Couchbase Server") + base.DebugfCtx(ctx, base.KeyAll, "uploaded ca.pem to Couchbase Server") // Make CBS read the newly uploaded certs resp, err = http.Post(restAPIURL.String()+"/node/controller/reloadCertificate", "", nil) @@ -343,9 +339,9 @@ func uploadCACertViaREST(couchbaseServerURL url.URL, ca *caPair) error { if resp.StatusCode != http.StatusOK { return fmt.Errorf("couldn't reloadCertificate: expected %d status code but got %d: %s", http.StatusOK, resp.StatusCode, respBody) } - base.DebugfCtx(logCtx, base.KeyAll, "triggered reload of certificates on Couchbase Server") + base.DebugfCtx(ctx, base.KeyAll, "triggered reload of certificates on Couchbase Server") - if err := enableX509ClientCertsInCouchbaseServer(restAPIURL); err != nil { + if err := enableX509ClientCertsInCouchbaseServer(ctx, restAPIURL); err != nil { return err } @@ -353,7 +349,7 @@ func uploadCACertViaREST(couchbaseServerURL url.URL, ca *caPair) error { } // couchbaseNodeConfiguredHostname returns the Couchbase node name for the given URL. -func couchbaseNodeConfiguredHostname(restAPIURL url.URL) (string, error) { +func couchbaseNodeConfiguredHostname(ctx context.Context, restAPIURL url.URL) (string, error) { resp, err := http.Get(restAPIURL.String() + "/pools/default") if err != nil { return "", err @@ -374,7 +370,7 @@ func couchbaseNodeConfiguredHostname(restAPIURL url.URL) (string, error) { if err := e.Decode(&respJSON); err != nil { return "", err } - base.DebugfCtx(context.Background(), base.KeyAll, "enabled X.509 client certs in Couchbase Server") + base.DebugfCtx(ctx, base.KeyAll, "enabled X.509 client certs in Couchbase Server") for _, n := range respJSON.NodesExt { if n.ThisNode { @@ -389,7 +385,7 @@ func couchbaseNodeConfiguredHostname(restAPIURL url.URL) (string, error) { func assertHostnameMatch(t *testing.T, couchbaseServerURL *url.URL) { restAPIURL := basicAuthRESTPIURLFromConnstrHost(*couchbaseServerURL) - nodeHostname, err := couchbaseNodeConfiguredHostname(restAPIURL) + nodeHostname, err := couchbaseNodeConfiguredHostname(base.TestCtx(t), restAPIURL) require.NoError(t, err) if nodeHostname != restAPIURL.Host { t.Fatal("Test requires " + base.TestEnvCouchbaseServerUrl + " to be the same as the Couchbase Server node hostname...\n\n" + @@ -397,7 +393,7 @@ func assertHostnameMatch(t *testing.T, couchbaseServerURL *url.URL) { } } -func enableX509ClientCertsInCouchbaseServer(restAPIURL url.URL) error { +func enableX509ClientCertsInCouchbaseServer(ctx context.Context, restAPIURL url.URL) error { clientAuthSettings := bytes.NewBufferString(` { "state": "enable", @@ -421,7 +417,7 @@ func enableX509ClientCertsInCouchbaseServer(restAPIURL url.URL) error { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { return fmt.Errorf("couldn't configure clientCertAuth: expected %d or %d status codes but got %d: %s", http.StatusOK, http.StatusAccepted, resp.StatusCode, respBody) } - base.DebugfCtx(context.Background(), base.KeyAll, "enabled X.509 client certs in Couchbase Server") + base.DebugfCtx(ctx, base.KeyAll, "enabled X.509 client certs in Couchbase Server") return nil }