From 68b6aa874231a5ee6865ab49fdefad06b199b724 Mon Sep 17 00:00:00 2001 From: Bastian Doetsch Date: Fri, 11 Oct 2024 15:16:03 +0200 Subject: [PATCH] fix: re-initialize auth provider on endpoint change and logout (#698) --- application/config/config.go | 6 ++---- application/server/configuration.go | 7 +++++-- application/server/server_test.go | 2 +- application/server/trust_test.go | 6 +++--- infrastructure/authentication/auth_service_impl.go | 1 + 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/application/config/config.go b/application/config/config.go index 5085d6ab3..48a98772e 100644 --- a/application/config/config.go +++ b/application/config/config.go @@ -153,7 +153,6 @@ func (c *CliSettings) DefaultBinaryInstallPath() string { } type Config struct { - scrubbingDict frameworkLogging.ScrubbingDict scrubbingWriter zerolog.LevelWriter cliSettings *CliSettings configFile string @@ -240,7 +239,6 @@ func NewFromExtension(engine workflow.Engine) *Config { func newConfig(engine workflow.Engine) *Config { c := &Config{} c.folderAdditionalParameters = make(map[string][]string) - c.scrubbingDict = frameworkLogging.ScrubbingDict{} c.logger = getNewScrubbingLogger(c) c.cliSettings = NewCliSettings(c) c.automaticAuthentication = true @@ -305,7 +303,7 @@ func initWorkFlowEngine(c *Config) { func getNewScrubbingLogger(c *Config) *zerolog.Logger { c.m.Lock() defer c.m.Unlock() - c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(logging.New(nil), c.scrubbingDict) + c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(logging.New(nil), make(frameworkLogging.ScrubbingDict)) writer := c.getConsoleWriter(c.scrubbingWriter) logger := zerolog.New(writer).With().Timestamp().Str("separator", "-").Str("method", "").Str("ext", "").Logger() return &logger @@ -598,7 +596,7 @@ func (c *Config) ConfigureLogging(server types.Server) { defer c.m.Unlock() // overwrite a potential already existing writer, so we have the latest settings - c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(zerolog.MultiLevelWriter(writers...), c.scrubbingDict) + c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(zerolog.MultiLevelWriter(writers...), make(frameworkLogging.ScrubbingDict)) writer := c.getConsoleWriter(c.scrubbingWriter) logger := zerolog.New(writer).With().Timestamp().Str("separator", "-").Str("method", "").Str("ext", "").Logger().Level(logLevel) c.logger = &logger diff --git a/application/server/configuration.go b/application/server/configuration.go index 3aed7706d..fdbd5b6b3 100644 --- a/application/server/configuration.go +++ b/application/server/configuration.go @@ -18,12 +18,13 @@ package server import ( "context" - "github.com/snyk/snyk-ls/internal/product" "os" "reflect" "strconv" "strings" + "github.com/snyk/snyk-ls/internal/product" + "github.com/creachadair/jrpc2" "github.com/creachadair/jrpc2/handler" @@ -259,7 +260,9 @@ func updateApiEndpoints(c *config.Config, settings types.Settings, initializatio endpointsUpdated := c.UpdateApiEndpoints(snykApiUrl) if endpointsUpdated && !initialization { - di.AuthenticationService().Logout(context.Background()) + authService := di.AuthenticationService() + authService.Logout(context.Background()) + authService.ConfigureProviders(c) workspace.Get().Clear() } diff --git a/application/server/server_test.go b/application/server/server_test.go index debbc54cd..971e1e3b3 100644 --- a/application/server/server_test.go +++ b/application/server/server_test.go @@ -653,7 +653,7 @@ func Test_initialize_handlesUntrustedFoldersWhenAutomaticAuthentication(t *testi } assert.Nil(t, err) - assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second*5, time.Millisecond) } func Test_initialize_handlesUntrustedFoldersWhenAuthenticated(t *testing.T) { diff --git a/application/server/trust_test.go b/application/server/trust_test.go index 33785d729..cff8d5a96 100644 --- a/application/server/trust_test.go +++ b/application/server/trust_test.go @@ -18,11 +18,12 @@ package server import ( "context" - "github.com/snyk/snyk-ls/domain/snyk/scanner" "os" "testing" "time" + "github.com/snyk/snyk-ls/domain/snyk/scanner" + "github.com/creachadair/jrpc2" "github.com/stretchr/testify/assert" @@ -148,7 +149,6 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) { c := config.CurrentConfig() c.SetTrustedFolderFeatureEnabled(true) - c.SetTrustedFolderFeatureEnabled(true) fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider) fakeAuthenticationProvider.IsAuthenticated = true @@ -180,7 +180,7 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) { } assert.NoError(t, err) - assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second*10, time.Millisecond) } func checkTrustMessageRequest(jsonRPCRecorder *testutil.JsonRPCRecorder) bool { diff --git a/infrastructure/authentication/auth_service_impl.go b/infrastructure/authentication/auth_service_impl.go index 70a7d6ca8..c1f06de75 100644 --- a/infrastructure/authentication/auth_service_impl.go +++ b/infrastructure/authentication/auth_service_impl.go @@ -108,6 +108,7 @@ func (a *AuthenticationServiceImpl) Logout(ctx context.Context) { a.errorReporter.CaptureError(err) } a.UpdateCredentials("", true) + a.ConfigureProviders(a.c) } // IsAuthenticated returns true if the token is verified