Skip to content

Commit

Permalink
Merge pull request #170 from wneessen/feature/105_enhance-default-tls…
Browse files Browse the repository at this point in the history
…-port

Add default ports and fallback behavior for SSL and TLS
  • Loading branch information
wneessen authored Jan 31, 2024
2 parents 1c39dc8 + 1efac7f commit 7dced4b
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 4 deletions.
96 changes: 92 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ import (

// Defaults
const (
// DefaultPort is the default connection port cto the SMTP server
// DefaultPort is the default connection port to the SMTP server
DefaultPort = 25

// DefaultPortSSL is the default connection port for SSL/TLS to the SMTP server
DefaultPortSSL = 465

// DefaultPortTLS is the default connection port for STARTTLS to the SMTP server
DefaultPortTLS = 587

// DefaultTimeout is the default connection timeout
DefaultTimeout = time.Second * 15

Expand Down Expand Up @@ -105,14 +111,15 @@ type Client struct {
// HELO/EHLO string for the greeting the target SMTP server
helo string

// Hostname of the target SMTP server cto connect cto
// Hostname of the target SMTP server to connect to
host string

// pass is the corresponding SMTP AUTH password
pass string

// Port of the SMTP server cto connect cto
port int
// Port of the SMTP server to connect to
port int
fallbackPort int

// sa is a pointer to smtp.Auth
sa smtp.Auth
Expand Down Expand Up @@ -246,13 +253,27 @@ func WithTimeout(t time.Duration) Option {
}

// WithSSL tells the client to use a SSL/TLS connection
//
// Deprecated: use WithSSLPort instead.
func WithSSL() Option {
return func(c *Client) error {
c.ssl = true
return nil
}
}

// WithSSLPort tells the client to use a SSL/TLS connection.
// It automatically sets the port to 465.
//
// When the SSL connection fails and fallback is set to true,
// the client will attempt to connect on port 25 using plaintext.
func WithSSLPort(fb bool) Option {
return func(c *Client) error {
c.SetSSLPort(true, fb)
return nil
}
}

// WithDebugLog tells the client to log incoming and outgoing messages of the SMTP client
// to StdErr
func WithDebugLog() Option {
Expand Down Expand Up @@ -282,13 +303,29 @@ func WithHELO(h string) Option {
}

// WithTLSPolicy tells the client to use the provided TLSPolicy
//
// Deprecated: use WithTLSPortPolicy instead.
func WithTLSPolicy(p TLSPolicy) Option {
return func(c *Client) error {
c.tlspolicy = p
return nil
}
}

// WithTLSPortPolicy tells the client to use the provided TLSPolicy,
// The correct port is automatically set.
//
// Port 587 is used for TLSMandatory and TLSOpportunistic.
// If the connection fails with TLSOpportunistic,
// a plaintext connection is attempted on port 25 as a fallback.
// NoTLS will allways use port 25.
func WithTLSPortPolicy(p TLSPolicy) Option {
return func(c *Client) error {
c.SetTLSPortPolicy(p)
return nil
}
}

// WithTLSConfig tells the client to use the provided *tls.Config
func WithTLSConfig(co *tls.Config) Option {
return func(c *Client) error {
Expand Down Expand Up @@ -430,11 +467,52 @@ func (c *Client) SetTLSPolicy(p TLSPolicy) {
c.tlspolicy = p
}

// SetTLSPortPolicy overrides the current TLSPolicy with the given TLSPolicy
// value. The correct port is automatically set.
//
// Port 587 is used for TLSMandatory and TLSOpportunistic.
// If the connection fails with TLSOpportunistic, a plaintext connection is
// attempted on port 25 as a fallback.
// NoTLS will allways use port 25.
func (c *Client) SetTLSPortPolicy(p TLSPolicy) {
c.port = DefaultPortTLS

if p == TLSOpportunistic {
c.fallbackPort = DefaultPort
}
if p == NoTLS {
c.port = DefaultPort
}

c.tlspolicy = p
}

// SetSSL tells the Client wether to use SSL or not
func (c *Client) SetSSL(s bool) {
c.ssl = s
}

// SetSSLPort tells the Client wether or not to use SSL and fallback.
// The correct port is automatically set.
//
// Port 465 is used when SSL set (true).
// Port 25 is used when SSL is unset (false).
// When the SSL connection fails and fb is set to true,
// the client will attempt to connect on port 25 using plaintext.
func (c *Client) SetSSLPort(ssl bool, fb bool) {
c.port = DefaultPort
if ssl {
c.port = DefaultPortSSL
}

c.fallbackPort = 0
if fb {
c.fallbackPort = DefaultPort
}

c.ssl = ssl
}

// SetDebugLog tells the Client whether debug logging is enabled or not
func (c *Client) SetDebugLog(v bool) {
c.dl = v
Expand Down Expand Up @@ -507,6 +585,10 @@ func (c *Client) DialWithContext(pc context.Context) error {
}
var err error
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
c.co, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return err
}
Expand Down Expand Up @@ -606,6 +688,12 @@ func (c *Client) checkConn() error {
return nil
}

// serverFallbackAddr returns the currently set combination of hostname
// and fallback port.
func (c *Client) serverFallbackAddr() string {
return fmt.Sprintf("%s:%d", c.host, c.fallbackPort)
}

// tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error {
if c.co == nil {
Expand Down
107 changes: 107 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,12 @@ func TestNewClientWithOptions(t *testing.T) {
{"WithTimeout()", WithTimeout(time.Second * 5), false},
{"WithTimeout()", WithTimeout(-10), true},
{"WithSSL()", WithSSL(), false},
{"WithSSLPort(false)", WithSSLPort(false), false},
{"WithSSLPort(true)", WithSSLPort(true), false},
{"WithHELO()", WithHELO(host), false},
{"WithHELO(); helo is empty", WithHELO(""), true},
{"WithTLSPolicy()", WithTLSPolicy(TLSOpportunistic), false},
{"WithTLSPortPolicy()", WithTLSPortPolicy(TLSOpportunistic), false},
{"WithTLSConfig()", WithTLSConfig(&tls.Config{}), false},
{"WithTLSConfig(); config is nil", WithTLSConfig(nil), true},
{"WithSMTPAuth()", WithSMTPAuth(SMTPAuthLogin), false},
Expand Down Expand Up @@ -235,6 +238,42 @@ func TestWithTLSPolicy(t *testing.T) {
}
}

// TestWithTLSPortPolicy tests the WithTLSPortPolicy() option for the NewClient() method
func TestWithTLSPortPolicy(t *testing.T) {
tests := []struct {
name string
value TLSPolicy
want string
wantPort int
fbPort int
sf bool
}{
{"Policy: TLSMandatory", TLSMandatory, TLSMandatory.String(), 587, 0, false},
{"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic.String(), 587, 25, false},
{"Policy: NoTLS", NoTLS, NoTLS.String(), 25, 0, false},
{"Policy: Invalid", -1, "UnknownPolicy", 587, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(DefaultHost, WithTLSPortPolicy(tt.value))
if err != nil && !tt.sf {
t.Errorf("failed to create new client: %s", err)
return
}
if c.tlspolicy.String() != tt.want {
t.Errorf("failed to set TLSPortPolicy. Want: %s, got: %s", tt.want, c.tlspolicy)
}
if c.port != tt.wantPort {
t.Errorf("failed to set TLSPortPolicy, wanted port: %d, got: %d", tt.wantPort, c.port)
}
if c.fallbackPort != tt.fbPort {
t.Errorf("failed to set TLSPortPolicy, wanted fallbakc port: %d, got: %d", tt.fbPort,
c.fallbackPort)
}
})
}
}

// TestSetTLSPolicy tests the SetTLSPolicy() method for the Client object
func TestSetTLSPolicy(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -312,6 +351,42 @@ func TestSetSSL(t *testing.T) {
}
}

// TestSetSSLPort tests the Client.SetSSLPort method
func TestClient_SetSSLPort(t *testing.T) {
tests := []struct {
name string
value bool
fb bool
port int
fbPort int
}{
{"SSL: on, fb: off", true, false, 465, 0},
{"SSL: on, fb: on", true, true, 465, 25},
{"SSL: off, fb: off", false, false, 25, 0},
{"SSL: off, fb: on", false, true, 25, 25},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(DefaultHost)
if err != nil {
t.Errorf("failed to create new client: %s", err)
return
}
c.SetSSLPort(tt.value, tt.fb)
if c.ssl != tt.value {
t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.ssl, tt.value)
}
if c.port != tt.port {
t.Errorf("failed to set SSLPort, wanted port: %d, got: %d", c.port, tt.port)
}
if c.fallbackPort != tt.fbPort {
t.Errorf("failed to set SSLPort, wanted fallback port: %d, got: %d", c.fallbackPort,
tt.fbPort)
}
})
}
}

// TestSetUsername tests the SetUsername method for the Client object
func TestSetUsername(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -550,6 +625,38 @@ func TestClient_DialWithContext(t *testing.T) {
}
}

// TestClient_DialWithContext_Fallback tests the Client.DialWithContext method with the fallback
// port functionality
func TestClient_DialWithContext_Fallback(t *testing.T) {
c, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
c.SetTLSPortPolicy(TLSOpportunistic)
c.port = 999
ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err)
return
}
if c.co == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.sc == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
}
if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}

c.port = 999
c.fallbackPort = 999
if err = c.DialWithContext(ctx); err == nil {
t.Error("dial with context was supposed to fail, but didn't")
return
}
}

// TestClient_DialWithContext_Debug tests the DialWithContext method for the Client object with debug
// logging enabled on the SMTP client
func TestClient_DialWithContext_Debug(t *testing.T) {
Expand Down

0 comments on commit 7dced4b

Please sign in to comment.