-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
utils/connfix: utility to simplify wrapping net.Conn without butcheri…
…ng its full interface When embedding net.Conn (or any interface) the Go compiler looses track of methods provided by implementations other than what the interface specifies. Sadly we cannot use generics to solve that. It would be best to have type FooConn[T net.Conn] struct { T } BUT the compiler claims "Embedded type cannot be a type parameter". This package fixes that for a very narrow, but important, usecase of overriding any of the following functions: - ReadFrom(r io.Reader) (int64, error) - WriteTo(w io.Writer) (int64, error) { - CloseWrite() Thanks to that the outer net.Conn can provide a single implementation and persist the original interface.
- Loading branch information
Showing
2 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved. | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
package connfix | ||
|
||
import ( | ||
"io" | ||
"net" | ||
) | ||
|
||
type ( | ||
readFromMixin struct{ net.Conn } | ||
writeToMixin struct{ net.Conn } | ||
closeWriteMixin struct{ net.Conn } | ||
) | ||
|
||
func (rf readFromMixin) ReadFrom(r io.Reader) (int64, error) { | ||
return rf.Conn.(io.ReaderFrom).ReadFrom(r) //nolint:forcetypeassert // we know the type is correct | ||
} | ||
|
||
var _ io.ReaderFrom = readFromMixin{} | ||
|
||
func (wt writeToMixin) WriteTo(w io.Writer) (int64, error) { | ||
return wt.Conn.(io.WriterTo).WriteTo(w) //nolint:forcetypeassert // we know the type is correct | ||
} | ||
|
||
var _ io.WriterTo = writeToMixin{} | ||
|
||
type _closeWriter interface { | ||
CloseWrite() error | ||
} | ||
|
||
func (cw closeWriteMixin) CloseWrite() error { | ||
return cw.Conn.(_closeWriter).CloseWrite() //nolint:forcetypeassert // we know the type is correct | ||
} | ||
|
||
var _ _closeWriter = closeWriteMixin{} | ||
|
||
const ( | ||
readerFrom = 1 << iota | ||
writerTo | ||
closeWriter | ||
) | ||
|
||
func flags(conn net.Conn) uint8 { | ||
var f uint8 | ||
if _, ok := conn.(io.ReaderFrom); ok { | ||
f |= readerFrom | ||
} | ||
if _, ok := conn.(io.WriterTo); ok { | ||
f |= writerTo | ||
} | ||
if _, ok := conn.(_closeWriter); ok { | ||
f |= closeWriter | ||
} | ||
return f | ||
} | ||
|
||
// Combine returns a net.Conn that combines the functionality of the outer and inner net.Conn. | ||
// It detects if the inner net.Conn provides any of the following functions: | ||
// - ReadFrom(r io.Reader) (int64, error) | ||
// - WriteTo(w io.Writer) (int64, error) { | ||
// - CloseWrite() | ||
// | ||
// and returns a net.Conn that implements the same interfaces. | ||
// | ||
// The outer net.Conn may also provide these functions, | ||
// they are used only if the inner net.Conn also provides them. | ||
// This allows the implementors of the outer net.Conn to provide implementations that are used when possible. | ||
func Combine(outer, inner net.Conn) net.Conn { | ||
readFromMixin := func() readFromMixin { | ||
if _, ok := outer.(io.ReaderFrom); ok { | ||
return readFromMixin{outer} | ||
} | ||
return readFromMixin{inner} | ||
} | ||
writeToMixin := func() writeToMixin { | ||
if _, ok := outer.(io.WriterTo); ok { | ||
return writeToMixin{outer} | ||
} | ||
return writeToMixin{inner} | ||
} | ||
closeWriteMixin := func() closeWriteMixin { | ||
if _, ok := outer.(_closeWriter); ok { | ||
return closeWriteMixin{outer} | ||
} | ||
return closeWriteMixin{inner} | ||
} | ||
|
||
switch flags(inner) { | ||
case 0: | ||
return struct { | ||
net.Conn | ||
}{outer} | ||
case readerFrom: | ||
return struct { | ||
net.Conn | ||
io.ReaderFrom | ||
}{outer, readFromMixin()} | ||
case writerTo: | ||
return struct { | ||
net.Conn | ||
io.WriterTo | ||
}{outer, writeToMixin()} | ||
case closeWriter: | ||
return struct { | ||
net.Conn | ||
_closeWriter | ||
}{outer, closeWriteMixin()} | ||
case readerFrom | writerTo: | ||
return struct { | ||
net.Conn | ||
io.ReaderFrom | ||
io.WriterTo | ||
}{outer, readFromMixin(), writeToMixin()} | ||
case readerFrom | closeWriter: | ||
return struct { | ||
net.Conn | ||
io.ReaderFrom | ||
_closeWriter | ||
}{outer, readFromMixin(), closeWriteMixin()} | ||
case writerTo | closeWriter: | ||
return struct { | ||
net.Conn | ||
io.WriterTo | ||
_closeWriter | ||
}{outer, writeToMixin(), closeWriteMixin()} | ||
case readerFrom | writerTo | closeWriter: | ||
return struct { | ||
net.Conn | ||
io.ReaderFrom | ||
io.WriterTo | ||
_closeWriter | ||
}{outer, readFromMixin(), writeToMixin(), closeWriteMixin()} | ||
default: | ||
panic("unreachable") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved. | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
package connfix | ||
|
||
import ( | ||
"bytes" | ||
"io" | ||
"net" | ||
"testing" | ||
) | ||
|
||
type testConn struct { | ||
net.Conn | ||
} | ||
|
||
var testAddr = &net.TCPAddr{} | ||
|
||
func (tc testConn) RemoteAddr() net.Addr { | ||
return testAddr | ||
} | ||
|
||
func (tc testConn) WriteTo(w io.Writer) (int64, error) { | ||
n, err := w.Write([]byte("test")) | ||
return int64(n), err | ||
} | ||
|
||
func TestCombineTCPConn(t *testing.T) { | ||
tconn := new(net.TCPConn) | ||
if flags(tconn) == 0 { | ||
t.Fatal("flags(tconn) == 0") | ||
} | ||
|
||
t.Run("basic", func(t *testing.T) { | ||
conn := Combine(testConn{tconn}, tconn) | ||
if flags(conn) != flags(tconn) { | ||
t.Fatal("flags(conn) != flags(tconn)") | ||
} | ||
if conn.RemoteAddr() != testAddr { | ||
t.Fatal("conn.RemoteAddr() != testAddr") | ||
} | ||
}) | ||
|
||
t.Run("overwrite", func(t *testing.T) { | ||
conn := Combine(testConn{tconn}, tconn) | ||
if flags(conn) != flags(tconn) { | ||
t.Fatal("flags(conn) != flags(tconn)") | ||
} | ||
var buf bytes.Buffer | ||
if _, err := conn.(io.WriterTo).WriteTo(&buf); err != nil { | ||
t.Fatal(err) | ||
} | ||
if buf.String() != "test" { | ||
t.Fatal("expected 'test', got", buf.String()) | ||
} | ||
}) | ||
|
||
t.Run("no overwrite", func(t *testing.T) { | ||
conn := Combine(testConn{tconn}, nil) | ||
if _, ok := conn.(io.WriterTo); ok { | ||
t.Fatal("expected no io.WriterTo") | ||
} | ||
}) | ||
} |