Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add linking tests #142

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions cmd/link_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package cmd

import (
"log"
"testing"
"time"

"github.com/charmbracelet/charm/client"
"github.com/charmbracelet/charm/proto"
"github.com/charmbracelet/charm/server"
"github.com/charmbracelet/charm/testserver"
)

// TestValidLinkGen
func TestValidLinkGen(t *testing.T) {
client1 := testserver.SetupTestServer(t)
client2, err := client.NewClientWithDefaults()
if err != nil {
t.Fatalf("error creating second client: %v", err)
}
lc := make(chan string, 1)
t.Run("link client 1", func(t *testing.T) {
t.Parallel()
lh := &linkHandlerTest{desc: "client1", linkChan: lc, approve: true}
// pass testing.t to it, assert error
err := client1.LinkGen(lh)
if err != nil {
t.Fatalf("failed to link client 1: %v", err)
}
})
t.Run("link client 2", func(t *testing.T) {
t.Parallel()
tok := <-lc
lh := &linkHandlerTest{desc: "client2", linkChan: lc}
err = client2.Link(lh, tok)
if err != nil {
t.Fatalf("failed to link client 2: %v", err)
}
})
}

// TestInvalidLinkGen
func TestInvalidLinkGen(t *testing.T) {
client1 := testserver.SetupTestServer(t)
client2, err := client.NewClientWithDefaults()
if err != nil {
t.Fatalf("error creating second client: %v", err)
}
lc := make(chan string, 1)
t.Run("link client 1", func(t *testing.T) {
t.Parallel()
lh := &linkHandlerTest{desc: "client1", linkChan: lc, approve: false}
// pass testing.t to it, assert error
err := client1.LinkGen(lh)
if err != nil {
t.Fatalf("failed to link client 1: %v", err)
}
})
t.Run("link client 2", func(t *testing.T) {
t.Parallel()
tok := <-lc
lh := &linkHandlerTest{desc: "client2", linkChan: lc}
err = client2.Link(lh, tok)
if err != nil {
t.Fatalf("failed to link client 2: %v", err)
}
if lh.status != requestDenied {
t.Fatalf("expected request denied, got: %d", lh.status)
}
})
}

// TestTimeoutLink
func TestTimeoutLink(t *testing.T) {
client1 := testserver.SetupTestServer(t, func(c *server.Config) *server.Config {
return c.WithLinkTimeout(5 * time.Second)
})
lc := make(chan string, 1)
t.Run("link client 1", func(t *testing.T) {
t.Parallel()
lh := &linkHandlerTest{desc: "client1", linkChan: lc, approve: true}
// pass testing.t to it, assert error
err := client1.LinkGen(lh)
if err != nil {
t.Fatalf("failed to link client 1: %v", err)
}
if lh.status != timedout {
t.Fatalf("expected link to timeout, got: %v", lh.status)
}
})
}

// use these status codes for assertions in tests
type statusCode uint

const (
ok statusCode = iota
timedout
invalidToken
requestDenied
)

type linkHandlerTest struct {
desc string
linkChan chan string
approve bool
status statusCode
}

func (lh *linkHandlerTest) TokenCreated(l *proto.Link) {
lh.printDebug("token created", l)
lh.linkChan <- string(l.Token)
lh.printDebug("token created sent to chan", l)
}

func (lh *linkHandlerTest) TokenSent(l *proto.Link) {
lh.printDebug("token sent", l)
}

func (lh *linkHandlerTest) ValidToken(l *proto.Link) {
lh.printDebug("valid token", l)
}

func (lh *linkHandlerTest) InvalidToken(l *proto.Link) {
lh.status = invalidToken
}

func (lh *linkHandlerTest) Request(l *proto.Link) bool {
return lh.approve
}

func (lh *linkHandlerTest) RequestDenied(l *proto.Link) {
lh.status = requestDenied
}

func (lh *linkHandlerTest) SameUser(l *proto.Link) {
lh.printDebug("same user", l)
}

func (lh *linkHandlerTest) Success(l *proto.Link) {
lh.printDebug("success", l)
lh.status = ok
}

func (lh *linkHandlerTest) Timeout(l *proto.Link) {
lh.status = timedout
}

func (lh linkHandlerTest) Error(l *proto.Link) {
lh.printDebug("error", l)
}

func (lh *linkHandlerTest) printDebug(msg string, l *proto.Link) {
log.Printf("%s %s:\t%v\n", lh.desc, msg, l)
}
4 changes: 2 additions & 2 deletions proto/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ const (
LinkStatusInvalidTokenRequest
)

// LinkTimeout is the length of time a Token is valid for.
const LinkTimeout = time.Minute
// DefaultLinkTimeout is the length of time a Token is valid for.
const DefaultLinkTimeout = time.Minute

// Token represent the confirmation code generated during linking.
type Token string
Expand Down
16 changes: 9 additions & 7 deletions server/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ func (sl *SSHLinker) User() *charm.User {
return sl.account
}

// LinkGen implements the proto.LinkTransport interface for the SSHLinker.
// LinkGen generates a link token and sends it to the user using the given
// link transport.
func (me *SSHServer) LinkGen(lt charm.LinkTransport) error {
u := lt.User()
tok := me.NewToken()
Expand All @@ -136,7 +137,7 @@ func (me *SSHServer) LinkGen(lt charm.LinkTransport) error {
}()
select {
case <-ch:
case <-time.After(charm.LinkTimeout):
case <-time.After(me.linkTimeout):
log.Printf("Link %s timed out", tok)
l.Status = charm.LinkStatusTimedOut
lt.TimedOut(l)
Expand Down Expand Up @@ -201,14 +202,15 @@ func (me *SSHServer) LinkGen(lt charm.LinkTransport) error {
l.Status = charm.LinkStatusRequestDenied
}
me.linkQueue.SendLinkRequest(lt, linkRequest, l)
case <-time.After(charm.LinkTimeout):
case <-time.After(me.linkTimeout):
log.Printf("Link %s timed out", tok)
lt.TimedOut(&charm.Link{Token: tok, Status: charm.LinkStatusTimedOut})
}
return nil
}

// LinkRequest implements the proto.LinkTransport interface for the SSHLinker.
// LinkRequest links a new machine to the link transport user account after
// validating the token.
func (me *SSHServer) LinkRequest(lt charm.LinkTransport, key string, token string, ip string) error {
l := &charm.Link{
Host: me.config.Host,
Expand Down Expand Up @@ -243,11 +245,11 @@ func (me *SSHServer) LinkRequest(lt charm.LinkTransport, key string, token strin
l.Status = charm.LinkStatusError
lt.Error(l)
}
case <-time.After(charm.LinkTimeout):
case <-time.After(me.linkTimeout):
l.Status = charm.LinkStatusTimedOut
lt.TimedOut(l)
}
case <-time.After(charm.LinkTimeout):
case <-time.After(me.linkTimeout):
l.Status = charm.LinkStatusTimedOut
lt.TimedOut(l)
}
Expand Down Expand Up @@ -394,7 +396,7 @@ func (s *channelLinkQueue) SendLinkRequest(lt charm.LinkTransport, lc chan *char
go func() {
select {
case lc <- l:
case <-time.After(charm.LinkTimeout):
case <-time.After(s.s.linkTimeout):
l.Status = charm.LinkStatusTimedOut
lt.TimedOut(l)
}
Expand Down
8 changes: 8 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log"
"net/url"
"path/filepath"
"time"

"github.com/caarlos0/env/v6"
charm "github.com/charmbracelet/charm/proto"
Expand Down Expand Up @@ -45,6 +46,7 @@ type Config struct {
FileStore storage.FileStore
Stats stats.Stats
linkQueue charm.LinkQueue
linkTimeout time.Duration
tlsConfig *tls.Config
jwtKeyPair JSONWebKeyPair
httpScheme string
Expand Down Expand Up @@ -112,6 +114,12 @@ func (cfg *Config) WithLinkQueue(q charm.LinkQueue) *Config {
return cfg
}

// WithLinkTimeout returns a Config with the provided link timeout.
func (cfg *Config) WithLinkTimeout(dur time.Duration) *Config {
cfg.linkTimeout = dur
return cfg
}

func (cfg *Config) httpURL() *url.URL {
s := fmt.Sprintf("%s://%s:%d", cfg.httpScheme, cfg.Host, cfg.HTTPPort)
if cfg.PublicURL != "" {
Expand Down
21 changes: 13 additions & 8 deletions server/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@ type SessionHandler func(s Session)
// SSHServer serves the SSH protocol and handles requests to authenticate and
// link Charm user accounts.
type SSHServer struct {
config *Config
db db.DB
server *ssh.Server
errorLog *log.Logger
linkQueue charm.LinkQueue
config *Config
db db.DB
server *ssh.Server
errorLog *log.Logger
linkQueue charm.LinkQueue
linkTimeout time.Duration
}

// NewSSHServer creates a new SSHServer from the provided Config.
func NewSSHServer(cfg *Config) (*SSHServer, error) {
s := &SSHServer{
config: cfg,
errorLog: cfg.errorLog,
linkQueue: cfg.linkQueue,
config: cfg,
errorLog: cfg.errorLog,
linkQueue: cfg.linkQueue,
linkTimeout: cfg.linkTimeout,
}
if s.linkTimeout == 0 {
s.linkTimeout = charm.DefaultLinkTimeout
}
if s.errorLog == nil {
s.errorLog = log.Default()
Expand Down
10 changes: 9 additions & 1 deletion testserver/testserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ import (
"github.com/charmbracelet/keygen"
)

// Option is a function that can be used to modify the server configuration.
type Option func(*server.Config) *server.Config

// SetupTestServer starts a test server and sets the needed environment
// variables so clients pick it up.
// It also returns a client forcing these settings in.
// Unless you use the given client, this is not really thread safe due
// to setting a bunch of environment variables.
func SetupTestServer(tb testing.TB) *client.Client {
func SetupTestServer(tb testing.TB, opts ...Option) *client.Client {
tb.Helper()

td := tb.TempDir()
Expand All @@ -40,6 +43,11 @@ func SetupTestServer(tb testing.TB) *client.Client {
}

cfg = cfg.WithKeys(kp.PublicKey(), kp.PrivateKeyPEM())

for _, opt := range opts {
cfg = opt(cfg)
}

s, err := server.NewServer(cfg)
if err != nil {
tb.Fatalf("new server error: %s", err)
Expand Down