Skip to content

Commit

Permalink
Testability changes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
metafeather committed May 4, 2024
1 parent dbc2da5 commit a3b42fc
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 42 deletions.
111 changes: 69 additions & 42 deletions modules/l4postgres/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,48 @@ func newMessageFromConn(cx *layer4.Connection) (*message, error) {
// Get bytes containing the message length
head := make([]byte, initMessageSizeLength)
if _, err := io.ReadFull(cx, head); err != nil {
return &message{}, err
return nil, err
}

// Get actual message length
data := make([]byte, binary.BigEndian.Uint32(head)-initMessageSizeLength)
if _, err := io.ReadFull(cx, data); err != nil {
return &message{}, err
return nil, err
}

cx.Logger.Debug("layer4.matchers.postgres.message",
zap.Int("len(head)", len(head)),
zap.Int("len(data)", len(data)),
)

return newMessageFromBytes(data), nil
}

// StartupMessage contains the values parsed from the startup message
// StartupMessage contains the values parsed from the first message received.
// This should be either a SSLRequest or StartupMessage
type startupMessage struct {
ProtocolVersion uint32
Parameters map[string]string
}

// IsSSL confirms this is a SSLRequest
func (s startupMessage) IsSSL() bool {
return isSSLRequest(s.ProtocolVersion)
}

// IsSupported confirms this is a supported version of Postgres
func (s startupMessage) IsSupported() bool {
return isSupported(s.ProtocolVersion)
}

// NewStartupMessage creates a new startupMessage from the message bytes
func newStartupMessage(b *message) *startupMessage {
return &startupMessage{
ProtocolVersion: b.ReadUint32(),
Parameters: parseParameters(b),
}
}

// MatchPostgres is able to match Postgres connections
type MatchPostgres struct {
Users map[string][]string
Expand All @@ -147,36 +171,26 @@ func (m MatchPostgres) Match(cx *layer4.Connection) (bool, error) {
return false, err
}

code := b.ReadUint32()
hasConfig := len(m.Users) == 0
m.startup = newStartupMessage(b)
hasConfig := len(m.Users) > 0

cx.Logger.Debug("layer4.matchers.postgres",
zap.String("matcher", fmt.Sprintf("%#v", m)),
zap.String("message", fmt.Sprintf("%#v", m.startup)),
)

// Finish if this is a SSLRequest and there are no other matchers
if code == sslRequestCode && !hasConfig {
if m.startup.IsSSL() && !hasConfig {
return true, nil
}

// Check supported protocol
if majorVersion := code >> 16; majorVersion < 3 {
if !m.startup.IsSupported() {
return false, errors.New("pg protocol < 3.0 is not supported")
}

// Try parsing Postgres Params
m.startup = &startupMessage{ProtocolVersion: code, Parameters: make(map[string]string)}
for {
k := b.ReadString()
if k == "" {
break
}
m.startup.Parameters[k] = b.ReadString()
}

cx.Logger.Debug("layer4.matchers.postgres",
zap.String("match.config", fmt.Sprintf("%v", m.Users)),
zap.String("startupMessage", fmt.Sprintf("%v", m.startup.Parameters)),
)

// Finish if no more matchers are configured
if hasConfig {
if !hasConfig {
return true, nil
}

Expand Down Expand Up @@ -230,33 +244,22 @@ func (m MatchPostgresClients) Match(cx *layer4.Connection) (bool, error) {
return false, err
}

code := b.ReadUint32()
m.startup = newStartupMessage(b)
cx.Logger.Debug("layer4.matchers.postgres_client",
zap.String("matcher", fmt.Sprintf("%#v", m)),
zap.String("message", fmt.Sprintf("%#v", m.startup)),
)

// Reject if this is a SSLRequest as it has no params
if code == sslRequestCode {
if m.startup.IsSSL() {
return false, nil
}

// Check supported protocol
if majorVersion := code >> 16; majorVersion < 3 {
if !m.startup.IsSupported() {
return false, errors.New("pg protocol < 3.0 is not supported")
}

// Try parsing Postgres Params
m.startup = &startupMessage{ProtocolVersion: code, Parameters: make(map[string]string)}
for {
k := b.ReadString()
if k == "" {
break
}
m.startup.Parameters[k] = b.ReadString()
}

cx.Logger.Debug("layer4.matchers.postgres_client",
zap.String("match.config", fmt.Sprintf("%v", m.Clients)),
zap.String("startupMessage", fmt.Sprintf("%v", m.startup.Parameters)),
)

// Is there a application_name to check?
name, ok := m.startup.Parameters["application_name"]
if !ok {
Expand Down Expand Up @@ -288,15 +291,39 @@ func (m MatchPostgresSSL) Match(cx *layer4.Connection) (bool, error) {
}

code := b.ReadUint32()
cx.Logger.Debug("layer4.matchers.postgres_ssl",
zap.String("matcher", fmt.Sprintf("%#v", m)),
)

// SSLRequest Message required?
if code == sslRequestCode {
if isSSLRequest(code) {
return m.Required, nil
}

return false, nil
}

func isSSLRequest(code uint32) bool {
return code == sslRequestCode
}

func isSupported(code uint32) bool {
majorVersion := code >> 16
return !(majorVersion > 3)
}

func parseParameters(b *message) map[string]string {
params := make(map[string]string)
for {
k := b.ReadString()
if k == "" {
break
}
params[k] = b.ReadString()
}
return params
}

// Interface guard
var _ layer4.ConnMatcher = (*MatchPostgres)(nil)
var _ layer4.ConnMatcher = (*MatchPostgresClients)(nil)
Expand Down
100 changes: 100 additions & 0 deletions modules/l4postgres/matcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package l4postgres

import (
"bytes"
"encoding/binary"
"io"
"net"
"sync"
"testing"

"github.com/mholt/caddy-l4/layer4"
"go.uber.org/zap"
)

// Example extends Message with utils to create example messages.
// ref: https://github.com/rueian/pgbroker/blob/master/message/util.go
type example struct {
message
}

func (b *example) WriteByte(i byte) {
b.data[b.offset] = i
b.offset++
}

func (b *example) WriteByteN(i []byte) {
for _, s := range i {
b.WriteByte(s)
}
}

func (b *example) WriteUint32(i uint32) {
binary.BigEndian.PutUint32(b.data[b.offset:b.offset+4], i)
b.offset += 4
}

func (b *example) WriteString(i string) {
b.WriteByteN([]byte(i))
b.WriteByte(0)
}

var ExampleSSLRequest = func() []byte {
x := &example{message: message{data: make([]byte, 8)}}
x.WriteUint32(8)
x.WriteUint32(sslRequestCode)
return x.data
}

var ExampleStartupMessage = func() []byte {
x := &example{message: message{data: make([]byte, 8)}}
x.WriteUint32(8)
x.WriteUint32(sslRequestCode)
x.WriteString("user=postgres")
x.offset = 0
return x.data
}

func assertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("Unexpected error: %s\n", err)
}
}

func closePipe(wg *sync.WaitGroup, c1 net.Conn, c2 net.Conn) {
wg.Wait()
_ = c1.Close()
_ = c2.Close()
}

func TestPostgresSSLMatch(t *testing.T) {
wg := &sync.WaitGroup{}
in, out := net.Pipe()
defer closePipe(wg, in, out)

cx := layer4.WrapConnection(in, &bytes.Buffer{}, zap.NewNop())

x := ExampleSSLRequest()

wg.Add(1)
go func() {
defer wg.Done()
defer out.Close()
_, err := out.Write(x)
assertNoError(t, err)
}()

matcher := MatchPostgresSSL{
Required: true,
}

matched, err := matcher.Match(cx)
assertNoError(t, err)

if !matched {
t.Fatalf("matcher did not match SSL")
}

_, _ = io.Copy(io.Discard, in)
}

0 comments on commit a3b42fc

Please sign in to comment.