Skip to content

Commit

Permalink
Client + Exports + Python: Add a startup boolean to getconfig
Browse files Browse the repository at this point in the history
To be used for autoconnect on startup. If autoconnect on startup set
to true

This ignores any callbacks that require user input (profile,
authorization & location callbacks)
  • Loading branch information
jwijenbergh committed Aug 16, 2023
1 parent e7a4427 commit d9e7837
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 124 deletions.
29 changes: 19 additions & 10 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,12 @@ func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error {
return nil
}

func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error {
func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool, startup bool) error {
// location
if srv.NeedsLocation() {
if startup {
return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no secure internet location is found. Please manually connect again", server.Name(srv))
}
err := c.locationCallback(ck)
if err != nil {
return i18nerr.Wrap(err, "The secure internet location could not be set")
Expand All @@ -398,6 +401,9 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
log.Logger.Debugf("failed to get tokens from client: %v", err)
}
if server.NeedsRelogin(context.Background(), srv) || forceauth {
if startup {
return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but you need to authorizate again. Please manually connect again", server.Name(srv))
}
// mark organizations as expired if the server is a secure internet server
b, berr := srv.Base()
if berr == nil && b.Type == srvtypes.TypeSecureInternet {
Expand All @@ -416,13 +422,16 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
return nil
}

func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error {
func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server, startup bool) error {
vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard)
if err != nil {
log.Logger.Warningf("failed to determine whether the current protocol is valid with error: %v", err)
return err
}
if !vp {
if startup {
return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no valid profiles were found. Please manually connect again", server.Name(srv))
}
vps, err := server.ValidProfiles(srv, c.SupportsWireguard)
if err != nil {
return i18nerr.Wrapf(err, "No suitable profiles could be found")
Expand Down Expand Up @@ -527,7 +536,7 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
}

// callbacks
err = c.callbacks(ck, srv, false)
err = c.callbacks(ck, srv, false, false)
// error is already UI wrapped
if err != nil {
return err
Expand All @@ -539,9 +548,9 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
return nil
}

func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) {
func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool, startup bool) (cfg *srvtypes.Configuration, err error) {
// do the callbacks to ensure valid profile, location and authorization
err = c.callbacks(ck, srv, forceAuth)
err = c.callbacks(ck, srv, forceAuth, startup)
if err != nil {
return nil, err
}
Expand All @@ -551,7 +560,7 @@ func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAu
return nil, err
}

err = c.profileCallback(ck, srv)
err = c.profileCallback(ck, srv, startup)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -586,7 +595,7 @@ func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Serv
}

// GetConfig gets a VPN configuration
func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) {
func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (cfg *srvtypes.Configuration, err error) {
c.mu.Lock()
defer c.mu.Unlock()
previousState := c.FSM.Current
Expand Down Expand Up @@ -626,11 +635,11 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
}

// get a config and retry with authorization if expired
cfg, err = c.config(ck, srv, pTCP, false)
cfg, err = c.config(ck, srv, pTCP, false, startup)
tErr := &oauth.TokensInvalidError{}
if err != nil && errors.As(err, &tErr) {
log.Logger.Debugf("the tokens were invalid, trying again...")
cfg, err = c.config(ck, srv, pTCP, true)
cfg, err = c.config(ck, srv, pTCP, true, startup)
}

// tokens might be updated, forward them
Expand Down Expand Up @@ -886,7 +895,7 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) {
// TODO: Maybe this can be deleted because we force auth now
server.MarkTokensForRenew(srv)
// run the callbacks by forcing auth
return c.callbacks(ck, srv, true)
return c.callbacks(ck, srv, true, false)
}

func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readRxBytes func() (int64, error)) (bool, error) {
Expand Down
74 changes: 67 additions & 7 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestServer(t *testing.T) {
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Connect error: %v", configErr)
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestTokenExpired(t *testing.T) {
t.Fatalf("Add error: %v", addErr)
}

_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)

if configErr != nil {
t.Fatalf("Connect error before expired: %v", configErr)
Expand All @@ -169,7 +169,7 @@ func TestTokenExpired(t *testing.T) {
// Wait for TTL so that the tokens expire
time.Sleep(time.Duration(expiredInt) * time.Second)

_, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
_, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)

if configErr != nil {
t.Fatalf("Connect error after expiry: %v", configErr)
Expand Down Expand Up @@ -215,7 +215,7 @@ func TestInvalidProfileCorrected(t *testing.T) {
t.Fatalf("Add error: %v", addErr)
}

_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
_, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)

if configErr != nil {
t.Fatalf("First connect error: %v", configErr)
Expand All @@ -234,7 +234,7 @@ func TestInvalidProfileCorrected(t *testing.T) {
previousProfile := base.Profiles.Current
base.Profiles.Current = "IDONOTEXIST"

_, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
_, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Second connect error: %v", configErr)
}
Expand All @@ -248,6 +248,66 @@ func TestInvalidProfileCorrected(t *testing.T) {
}
}

// TestConfigStartup tests if the 'startup' variable for getconfig behaves as expected
func TestConfigStartup(t *testing.T) {
serverURI := getServerURI(t)
ck := cookie.NewWithContext(context.Background())
defer ck.Cancel() //nolint:errcheck
dir := t.TempDir()
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, &ck, old, new, data)
return true
},
false,
)
if err != nil {
t.Fatalf("Creating client error: %v", err)
}
err = state.Register()
if err != nil {
t.Fatalf("Failed to register with error: %v", err)
}
// we set true as last argument here such that no callbacks are ran
err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, true)
if err != nil {
t.Fatalf("Failed to add server for trying config startup: %v", err)
}
testTrue := func() {
// Now get config with setting startup to true
startup := true
_, err := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, startup)
// this should fail as we have not authorized yet/chosen profile and startup=true does not do these callbacks
if err == nil {
t.Fatal("Got no error after getting config with startup true")
}
if !strings.HasPrefix(err.Error(), "The client tried to autoconnect to the VPN server") {
t.Fatalf("GetConfig error for GetConfig with startup=true is not what we expect: %v", err)
}
}
testFalse := func() {
startup := false
// This should succeed
_, err := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, startup)
// this should fail as we have not authorized yet/chosen profile
if err != nil {
t.Fatalf("Got error after getting config with startup=false: %v", err)
}
}
testTrue()
testFalse()

// set invalid authorization and test again
// we cannot test by setting invalid profile because the server only has 1 profile
// TODO: support multiple profiles in the test server
state.Servers.CustomServers.Map[serverURI].OAuth().SetTokenRenew()
testTrue()
testFalse()
}

// Test if prefer tcp is handled correctly by checking the returned config and config type.
func TestPreferTCP(t *testing.T) {
serverURI := getServerURI(t)
Expand Down Expand Up @@ -278,7 +338,7 @@ func TestPreferTCP(t *testing.T) {
}

// get a config with preferTCP set to true
config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true)
config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true, false)

// Test server should accept prefer TCP!
if config.Protocol != protocol.OpenVPN {
Expand All @@ -295,7 +355,7 @@ func TestPreferTCP(t *testing.T) {
}

// get a config with preferTCP set to false
config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Config error: %v", configErr)
}
Expand Down
Loading

0 comments on commit d9e7837

Please sign in to comment.