Skip to content

Commit

Permalink
fix: re-initialize auth provider on endpoint change and logout (#698)
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiandoetsch authored Oct 11, 2024
1 parent 0ba9106 commit 68b6aa8
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 10 deletions.
6 changes: 2 additions & 4 deletions application/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ func (c *CliSettings) DefaultBinaryInstallPath() string {
}

type Config struct {
scrubbingDict frameworkLogging.ScrubbingDict
scrubbingWriter zerolog.LevelWriter
cliSettings *CliSettings
configFile string
Expand Down Expand Up @@ -240,7 +239,6 @@ func NewFromExtension(engine workflow.Engine) *Config {
func newConfig(engine workflow.Engine) *Config {
c := &Config{}
c.folderAdditionalParameters = make(map[string][]string)
c.scrubbingDict = frameworkLogging.ScrubbingDict{}
c.logger = getNewScrubbingLogger(c)
c.cliSettings = NewCliSettings(c)
c.automaticAuthentication = true
Expand Down Expand Up @@ -305,7 +303,7 @@ func initWorkFlowEngine(c *Config) {
func getNewScrubbingLogger(c *Config) *zerolog.Logger {
c.m.Lock()
defer c.m.Unlock()
c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(logging.New(nil), c.scrubbingDict)
c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(logging.New(nil), make(frameworkLogging.ScrubbingDict))
writer := c.getConsoleWriter(c.scrubbingWriter)
logger := zerolog.New(writer).With().Timestamp().Str("separator", "-").Str("method", "").Str("ext", "").Logger()
return &logger
Expand Down Expand Up @@ -598,7 +596,7 @@ func (c *Config) ConfigureLogging(server types.Server) {
defer c.m.Unlock()

// overwrite a potential already existing writer, so we have the latest settings
c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(zerolog.MultiLevelWriter(writers...), c.scrubbingDict)
c.scrubbingWriter = frameworkLogging.NewScrubbingWriter(zerolog.MultiLevelWriter(writers...), make(frameworkLogging.ScrubbingDict))
writer := c.getConsoleWriter(c.scrubbingWriter)
logger := zerolog.New(writer).With().Timestamp().Str("separator", "-").Str("method", "").Str("ext", "").Logger().Level(logLevel)
c.logger = &logger
Expand Down
7 changes: 5 additions & 2 deletions application/server/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ package server

import (
"context"
"github.com/snyk/snyk-ls/internal/product"
"os"
"reflect"
"strconv"
"strings"

"github.com/snyk/snyk-ls/internal/product"

"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler"

Expand Down Expand Up @@ -259,7 +260,9 @@ func updateApiEndpoints(c *config.Config, settings types.Settings, initializatio
endpointsUpdated := c.UpdateApiEndpoints(snykApiUrl)

if endpointsUpdated && !initialization {
di.AuthenticationService().Logout(context.Background())
authService := di.AuthenticationService()
authService.Logout(context.Background())
authService.ConfigureProviders(c)
workspace.Get().Clear()
}

Expand Down
2 changes: 1 addition & 1 deletion application/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ func Test_initialize_handlesUntrustedFoldersWhenAutomaticAuthentication(t *testi
}

assert.Nil(t, err)
assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second*5, time.Millisecond)
}

func Test_initialize_handlesUntrustedFoldersWhenAuthenticated(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions application/server/trust_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ package server

import (
"context"
"github.com/snyk/snyk-ls/domain/snyk/scanner"
"os"
"testing"
"time"

"github.com/snyk/snyk-ls/domain/snyk/scanner"

"github.com/creachadair/jrpc2"
"github.com/stretchr/testify/assert"

Expand Down Expand Up @@ -148,7 +149,6 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) {

c := config.CurrentConfig()
c.SetTrustedFolderFeatureEnabled(true)
c.SetTrustedFolderFeatureEnabled(true)

fakeAuthenticationProvider := di.AuthenticationService().Provider().(*authentication.FakeAuthenticationProvider)
fakeAuthenticationProvider.IsAuthenticated = true
Expand Down Expand Up @@ -180,7 +180,7 @@ func Test_MultipleFoldersInRootDirWithOnlyOneTrusted(t *testing.T) {
}

assert.NoError(t, err)
assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return checkTrustMessageRequest(jsonRPCRecorder) }, time.Second*10, time.Millisecond)
}

func checkTrustMessageRequest(jsonRPCRecorder *testutil.JsonRPCRecorder) bool {
Expand Down
1 change: 1 addition & 0 deletions infrastructure/authentication/auth_service_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func (a *AuthenticationServiceImpl) Logout(ctx context.Context) {
a.errorReporter.CaptureError(err)
}
a.UpdateCredentials("", true)
a.ConfigureProviders(a.c)
}

// IsAuthenticated returns true if the token is verified
Expand Down

0 comments on commit 68b6aa8

Please sign in to comment.