Skip to content

Commit

Permalink
utils/connfix: utility to simplify wrapping net.Conn without butcheri…
Browse files Browse the repository at this point in the history
…ng its full interface

When embedding net.Conn (or any interface) the Go compiler looses track of methods provided by implementations other than what the interface specifies.

Sadly we cannot use generics to solve that. It would be best to have

type FooConn[T net.Conn] struct {
    T
}

BUT the compiler claims "Embedded type cannot be a type parameter".

This package fixes that for a very narrow, but important, usecase of overriding any of the following functions:
 - ReadFrom(r io.Reader) (int64, error)
 - WriteTo(w io.Writer) (int64, error) {
 - CloseWrite()

Thanks to that the outer net.Conn can provide a single implementation and persist the original interface.
  • Loading branch information
mmatczuk committed Oct 8, 2024
1 parent 0ffc990 commit fbf2023
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 0 deletions.
141 changes: 141 additions & 0 deletions utils/connfix/connfix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package connfix

import (
"io"
"net"
)

type (
readFromMixin struct{ net.Conn }
writeToMixin struct{ net.Conn }
closeWriteMixin struct{ net.Conn }
)

func (rf readFromMixin) ReadFrom(r io.Reader) (int64, error) {
return rf.Conn.(io.ReaderFrom).ReadFrom(r) //nolint:forcetypeassert // we know the type is correct
}

var _ io.ReaderFrom = readFromMixin{}

func (wt writeToMixin) WriteTo(w io.Writer) (int64, error) {
return wt.Conn.(io.WriterTo).WriteTo(w) //nolint:forcetypeassert // we know the type is correct
}

var _ io.WriterTo = writeToMixin{}

type _closeWriter interface {
CloseWrite() error
}

func (cw closeWriteMixin) CloseWrite() error {
return cw.Conn.(_closeWriter).CloseWrite() //nolint:forcetypeassert // we know the type is correct
}

var _ _closeWriter = closeWriteMixin{}

const (
readerFrom = 1 << iota
writerTo
closeWriter
)

func flags(conn net.Conn) uint8 {
var f uint8
if _, ok := conn.(io.ReaderFrom); ok {
f |= readerFrom
}
if _, ok := conn.(io.WriterTo); ok {
f |= writerTo
}
if _, ok := conn.(_closeWriter); ok {
f |= closeWriter
}
return f
}

// Combine returns a net.Conn that combines the functionality of the outer and inner net.Conn.
// It detects if the inner net.Conn provides any of the following functions:
// - ReadFrom(r io.Reader) (int64, error)
// - WriteTo(w io.Writer) (int64, error) {
// - CloseWrite()
//
// and returns a net.Conn that implements the same interfaces.
//
// The outer net.Conn may also provide these functions,
// they are used only if the inner net.Conn also provides them.
// This allows the implementors of the outer net.Conn to provide implementations that are used when possible.
func Combine(outer, inner net.Conn) net.Conn {
readFromMixin := func() readFromMixin {
if _, ok := outer.(io.ReaderFrom); ok {
return readFromMixin{outer}
}
return readFromMixin{inner}
}
writeToMixin := func() writeToMixin {
if _, ok := outer.(io.WriterTo); ok {
return writeToMixin{outer}
}
return writeToMixin{inner}
}
closeWriteMixin := func() closeWriteMixin {
if _, ok := outer.(_closeWriter); ok {
return closeWriteMixin{outer}
}
return closeWriteMixin{inner}
}

switch flags(inner) {
case 0:
return struct {
net.Conn
}{outer}
case readerFrom:
return struct {
net.Conn
io.ReaderFrom
}{outer, readFromMixin()}
case writerTo:
return struct {
net.Conn
io.WriterTo
}{outer, writeToMixin()}
case closeWriter:
return struct {
net.Conn
_closeWriter
}{outer, closeWriteMixin()}
case readerFrom | writerTo:
return struct {
net.Conn
io.ReaderFrom
io.WriterTo
}{outer, readFromMixin(), writeToMixin()}
case readerFrom | closeWriter:
return struct {
net.Conn
io.ReaderFrom
_closeWriter
}{outer, readFromMixin(), closeWriteMixin()}
case writerTo | closeWriter:
return struct {
net.Conn
io.WriterTo
_closeWriter
}{outer, writeToMixin(), closeWriteMixin()}
case readerFrom | writerTo | closeWriter:
return struct {
net.Conn
io.ReaderFrom
io.WriterTo
_closeWriter
}{outer, readFromMixin(), writeToMixin(), closeWriteMixin()}
default:
panic("unreachable")
}
}
67 changes: 67 additions & 0 deletions utils/connfix/connfix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package connfix

import (
"bytes"
"io"
"net"
"testing"
)

type testConn struct {
net.Conn
}

var testAddr = &net.TCPAddr{}

func (tc testConn) RemoteAddr() net.Addr {
return testAddr
}

func (tc testConn) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte("test"))
return int64(n), err
}

func TestCombineTCPConn(t *testing.T) {
tconn := new(net.TCPConn)
if flags(tconn) == 0 {
t.Fatal("flags(tconn) == 0")
}

t.Run("basic", func(t *testing.T) {
conn := Combine(testConn{tconn}, tconn)
if flags(conn) != flags(tconn) {
t.Fatal("flags(conn) != flags(tconn)")
}
if conn.RemoteAddr() != testAddr {
t.Fatal("conn.RemoteAddr() != testAddr")
}
})

t.Run("overwrite", func(t *testing.T) {
conn := Combine(testConn{tconn}, tconn)
if flags(conn) != flags(tconn) {
t.Fatal("flags(conn) != flags(tconn)")
}
var buf bytes.Buffer
if _, err := conn.(io.WriterTo).WriteTo(&buf); err != nil {
t.Fatal(err)
}
if buf.String() != "test" {
t.Fatal("expected 'test', got", buf.String())
}
})

t.Run("no overwrite", func(t *testing.T) {
conn := Combine(testConn{tconn}, nil)
if _, ok := conn.(io.WriterTo); ok {
t.Fatal("expected no io.WriterTo")
}
})
}

0 comments on commit fbf2023

Please sign in to comment.