diff --git a/.gitignore b/.gitignore index 018b0ec..40f36d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .vscode cmd/veil/veil cmd/veil-verify/veil-verify +cmd/veil-proxy/veil-proxy cover.html cover.out *.tar diff --git a/Makefile b/Makefile index 39d2f21..2bdea5d 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,8 @@ prog = veil prog_dir = cmd/veil verify_prog = veil-verify verify_prog_dir = cmd/veil-verify +proxy_prog = veil-proxy +proxy_prog_dir = cmd/veil-proxy godeps = go.mod go.sum $(shell find cmd internal -name "*.go" -type f) image_tag := $(prog) @@ -89,6 +91,10 @@ $(verify_prog): $(godeps) @go build -C $(verify_prog_dir) -o $(verify_prog) @-sha1sum "$(verify_prog_dir)/$(verify_prog)" +$(proxy_prog): $(godeps) + @go build -C $(proxy_prog_dir) -o $(proxy_prog) + @-sha1sum "$(proxy_prog_dir)/$(proxy_prog)" + .PHONY: clean clean: rm -f $(prog_dir)/$(prog) diff --git a/cmd/veil-proxy/main.go b/cmd/veil-proxy/main.go new file mode 100644 index 0000000..cd9c34e --- /dev/null +++ b/cmd/veil-proxy/main.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "flag" + "io" + "log" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "sync" + + "github.com/Amnesic-Systems/veil/internal/errs" + "github.com/Amnesic-Systems/veil/internal/net/nat" + "github.com/Amnesic-Systems/veil/internal/net/proxy" + "github.com/Amnesic-Systems/veil/internal/net/tun" + "github.com/mdlayher/vsock" +) + +type config struct { + profile bool + port int +} + +func parseFlags(out io.Writer, args []string) (_ *config, err error) { + defer errs.Wrap(&err, "failed to parse flags") + + fs := flag.NewFlagSet("veil-proxy", flag.ContinueOnError) + fs.SetOutput(out) + + profile := fs.Bool( + "profile", + false, + "Enable profiling.", + ) + port := fs.Int( + "port", + 1024, + "VSOCK port that the enclave connects to.", + ) + if err := fs.Parse(args); err != nil { + return nil, err + } + + return &config{ + profile: *profile, + port: *port, + }, nil +} + +func listenVSOCK(port uint32) (_ net.Listener, err error) { + defer errs.Wrap(&err, "failed to create VSOCK listener") + + cid, err := vsock.ContextID() + if err != nil { + return nil, err + } + return vsock.ListenContextID(cid, port, nil) +} + +func acceptLoop(ln net.Listener) { + // Print errors that occur while forwarding packets. + ch := make(chan error) + defer close(ch) + go func(ch chan error) { + for err := range ch { + log.Print(err) + } + }(ch) + + // Listen for connections from the enclave and begin forwarding packets + // once a new connection is established. At any given point, we only expect + // to have a single TCP-over-VSOCK connection with the enclave. + for { + tunDev, err := tun.SetupTunAsProxy() + if err != nil { + log.Printf("Error creating tun device: %v", err) + continue + } + log.Print("Created tun device.") + + log.Println("Waiting for new connection from enclave.") + vm, err := ln.Accept() + if err != nil { + log.Printf("Error accepting connection: %v", err) + continue + } + log.Printf("Accepted new connection from %s.", vm.RemoteAddr()) + + var wg sync.WaitGroup + wg.Add(2) + go proxy.VsockToTun(vm, tunDev, ch, &wg) + go proxy.TunToVsock(tunDev, vm, ch, &wg) + wg.Wait() + } +} + +func run(ctx context.Context, out io.Writer, args []string) (origErr error) { + _, cancel := signal.NotifyContext(ctx, os.Interrupt) + defer cancel() + + cfg, err := parseFlags(out, args) + if err != nil { + return err + } + + // Enable NAT. + if err := nat.Enable(); err != nil { + return errs.Add(err, "failed to enable NAT") + } + log.Print("Enabled NAT.") + defer func() { + errs.Join(&origErr, errs.Add(nat.Disable(), "failed to disable NAT")) + log.Print("Disabled NAT.") + }() + + // Create a VSOCK listener that listens for incoming connections from the + // enclave. + ln, err := listenVSOCK(uint32(cfg.port)) + if err != nil { + return err + } + defer func() { + errs.Join(&origErr, errs.Add(ln.Close(), "failed to close listener")) + }() + + // If desired, set up a Web server for the profiler. + if cfg.profile { + go func() { + const hostPort = "localhost:6060" + log.Printf("Starting profiling Web server at: http://%s", hostPort) + err := http.ListenAndServe(hostPort, nil) + if err != nil && err != http.ErrServerClosed { + log.Printf("Error running profiling server: %v", err) + } + }() + } + + // Accept new connections from the VSOCK listener and begin forwarding + // packets. + acceptLoop(ln) + return nil +} + +func main() { + if err := run(context.Background(), os.Stdout, os.Args[1:]); err != nil { + log.Fatalf("Failed to run proxy: %v", err) + } +} diff --git a/go.mod b/go.mod index efbc4e3..909f1fc 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( github.com/Amnesic-Systems/nitriding-proxy v0.1.1 + github.com/coreos/go-iptables v0.8.0 github.com/docker/docker v27.3.1+incompatible github.com/fatih/color v1.18.0 github.com/fxamacker/cbor/v2 v2.7.0 @@ -14,6 +15,7 @@ require ( github.com/moby/term v0.5.0 github.com/opencontainers/image-spec v1.1.0 github.com/stretchr/testify v1.9.0 + golang.org/x/net v0.33.0 golang.org/x/sys v0.28.0 ) @@ -22,7 +24,6 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/containerd/log v0.1.0 // indirect - github.com/coreos/go-iptables v0.8.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -49,7 +50,6 @@ require ( go.opentelemetry.io/otel/metric v1.32.0 // indirect go.opentelemetry.io/otel/sdk v1.28.0 // indirect go.opentelemetry.io/otel/trace v1.32.0 // indirect - golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/time v0.8.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect diff --git a/go.sum b/go.sum index d1a6514..a87cd07 100644 --- a/go.sum +++ b/go.sum @@ -122,8 +122,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -137,15 +135,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/errs/errs.go b/internal/errs/errs.go index a59f5bc..94783e6 100644 --- a/internal/errs/errs.go +++ b/internal/errs/errs.go @@ -29,3 +29,10 @@ func Add(err error, str string, args ...any) error { } return fmt.Errorf("%s: %w", fmt.Sprintf(str, args...), err) } + +func Join(origErr *error, new error) { + if origErr == nil { + return + } + *origErr = errors.Join(*origErr, new) +} diff --git a/internal/net/nat/nat.go b/internal/net/nat/nat.go new file mode 100644 index 0000000..ac182d5 --- /dev/null +++ b/internal/net/nat/nat.go @@ -0,0 +1,44 @@ +package nat + +import ( + "github.com/Amnesic-Systems/veil/internal/net/tun" + "github.com/coreos/go-iptables/iptables" +) + +// Enable enables our iptables NAT rules, which connect the enclave to the +// Internet. +func Enable() error { + return applyRules(true) +} + +// Disable disables our iptables NAT rules. +func Disable() error { + return applyRules(false) +} + +func applyRules(toggle bool) error { + t, err := iptables.New() + if err != nil { + return err + } + + f := t.AppendUnique + if !toggle { + f = t.DeleteIfExists + } + + var iptablesRules = [][]string{ + {"nat", "POSTROUTING", "-s", "10.0.0.0/24", "-j", "MASQUERADE"}, + {"filter", "FORWARD", "-i", tun.Name, "-s", "10.0.0.0/24", "-j", "ACCEPT"}, + {"filter", "FORWARD", "-o", tun.Name, "-d", "10.0.0.0/24", "-j", "ACCEPT"}, + } + + const table, chain, rulespec = 0, 1, 2 + for _, r := range iptablesRules { + if err := f(r[table], r[chain], r[rulespec:]...); err != nil { + return err + } + } + + return nil +} diff --git a/internal/net/proxy/proxy.go b/internal/net/proxy/proxy.go new file mode 100644 index 0000000..b77b562 --- /dev/null +++ b/internal/net/proxy/proxy.go @@ -0,0 +1,79 @@ +package proxy + +import ( + "encoding/binary" + "fmt" + "io" + "sync" + + "github.com/Amnesic-Systems/veil/internal/net/tun" +) + +const lenBufSize = 2 + +// TunToVsock forwards network packets from the tun device to our +// TCP-over-VSOCK connection. The function keeps on forwarding packets until we +// encounter an error or EOF. Errors (including EOF) are written to the given +// channel. +func TunToVsock(from io.ReadCloser, to io.WriteCloser, ch chan error, wg *sync.WaitGroup) { + defer to.Close() + defer wg.Done() + var ( + err error + pktLenBuf = make([]byte, lenBufSize) + pktBuf = make([]byte, tun.MTU) + ) + + for { + // Read a network packet from the tun interface. + nr, rerr := from.Read(pktBuf) + if nr > 0 { + // Forward the network packet to our TCP-over-VSOCK connection. + binary.BigEndian.PutUint16(pktLenBuf, uint16(nr)) + if _, werr := to.Write(append(pktLenBuf, pktBuf[:nr]...)); werr != nil { + err = werr + break + } + } + if rerr != nil { + err = rerr + break + } + } + ch <- fmt.Errorf("stopped tun-to-vsock forwarding: %w", err) +} + +// VsockToTun forwards network packets from our TCP-over-VSOCK connection to +// the tun interface. The function keeps on forwarding packets until we +// encounter an error or EOF. Errors (including EOF) are written to the given +// channel. +func VsockToTun(from io.ReadCloser, to io.WriteCloser, ch chan error, wg *sync.WaitGroup) { + defer to.Close() + defer wg.Done() + var ( + err error + pktLen uint16 + pktLenBuf = make([]byte, lenBufSize) + pktBuf = make([]byte, tun.MTU) + ) + + for { + // Read the length prefix that tells us the size of the subsequent + // packet. + if _, err = io.ReadFull(from, pktLenBuf); err != nil { + break + } + pktLen = binary.BigEndian.Uint16(pktLenBuf) + + // Read the packet. + if _, err = io.ReadFull(from, pktBuf[:pktLen]); err != nil { + break + } + + // Forward the packet to the tun interface. + if _, err = to.Write(pktBuf[:pktLen]); err != nil { + break + } + } + ch <- fmt.Errorf("stopped vsock-to-tun forwarding: %w", err) +} diff --git a/internal/net/proxy/proxy_test.go b/internal/net/proxy/proxy_test.go new file mode 100644 index 0000000..6cf3e52 --- /dev/null +++ b/internal/net/proxy/proxy_test.go @@ -0,0 +1,80 @@ +package proxy + +import ( + "bytes" + "crypto/rand" + "errors" + "io" + "net" + "sync" + "testing" + + "github.com/Amnesic-Systems/veil/internal/net/tun" + "golang.org/x/net/nettest" +) + +func assertEq(t *testing.T, is, should interface{}) { + t.Helper() + if should != is { + t.Fatalf("Expected value\n%v\nbut got\n%v", should, is) + } +} + +// buffer implements io.ReadWriteCloser. +type buffer struct { + *bytes.Buffer +} + +func (b *buffer) Close() error { + return nil +} + +func TestNettest(t *testing.T) { + mkPipe := func() (c1, c2 net.Conn, stop func(), err error) { + var ( + in, out = net.Pipe() + fwd1, fwd2 = net.Pipe() + wg = sync.WaitGroup{} + ch = make(chan error) + ) + wg.Add(2) + go TunToVsock(in, fwd1, ch, &wg) + go VsockToTun(fwd2, out, ch, &wg) + return in, out, func() {}, nil + } + nettest.TestConn(t, nettest.MakePipe(mkPipe)) +} + +func TestAToB(t *testing.T) { + var ( + err error + wg sync.WaitGroup + ch = make(chan error) + conn1, conn2 = net.Pipe() + sendBuf = make([]byte, tun.MTU*2) + recvBuf = &buffer{ + Buffer: new(bytes.Buffer), + } + ) + + // We only expect to see errors containing io.EOF. + go func() { + for err := range ch { + assertEq(t, errors.Is(err, io.EOF), true) + } + }() + + // Fill sendBuf with random data. + _, err = rand.Read(sendBuf) + assertEq(t, err, nil) + + wg.Add(2) + go TunToVsock(io.NopCloser(bytes.NewReader(sendBuf)), conn1, ch, &wg) + go VsockToTun(conn2, recvBuf, ch, &wg) + wg.Wait() + + assertEq(t, bytes.Equal( + sendBuf, + recvBuf.Bytes(), + ), true) +} diff --git a/internal/net/tun/tun.go b/internal/net/tun/tun.go new file mode 100644 index 0000000..6dd5fcd --- /dev/null +++ b/internal/net/tun/tun.go @@ -0,0 +1,6 @@ +package tun + +const ( + Name = "tun0" + MTU = 65535 // The maximum-allowed MTU for the tun interface. +) diff --git a/internal/net/tun/tun_darwin.go b/internal/net/tun/tun_darwin.go new file mode 100644 index 0000000..3d3ad29 --- /dev/null +++ b/internal/net/tun/tun_darwin.go @@ -0,0 +1,17 @@ +package tun + +import ( + "os" +) + +// veil-proxy does not support macOS but we can at least make it compile by +// implementing the following functions. +const err = "not implemented on darwin" + +func SetupTunAsProxy() (*os.File, error) { + panic(err) +} + +func SetupTunAsEnclave() (*os.File, error) { + panic(err) +} diff --git a/internal/net/tun/tun_linux.go b/internal/net/tun/tun_linux.go new file mode 100644 index 0000000..0efd089 --- /dev/null +++ b/internal/net/tun/tun_linux.go @@ -0,0 +1,112 @@ +package tun + +import ( + "fmt" + "net" + "os" + "unsafe" + + "github.com/milosgajdos/tenus" + "golang.org/x/sys/unix" +) + +const ( + asEnclave = iota + asProxy +) + +type ifReq struct { + Name [0x10]byte + Flags uint16 + pad [0x28 - 0x10 - 2]byte +} + +// SetupTunAsProxy sets up a tun interface and returns a ready-to-use file +// descriptor. +func SetupTunAsProxy() (*os.File, error) { + return setupTun(asProxy) +} + +// SetupTunAsEnclave sets up a tun interface and returns a ready-to-use file +// descriptor. +func SetupTunAsEnclave() (*os.File, error) { + return setupTun(asEnclave) +} + +// setupTun creates and configures a tun interface. The given typ must be +// asEnclave or asProxy. +func setupTun(typ int) (*os.File, error) { + fd, err := createTun() + if err != nil { + return nil, err + } + if err := configureTun(typ); err != nil { + return nil, err + } + + return fd, nil +} + +// createTun returns a ready-to-use file descriptor for our tun interface. +func createTun() (*os.File, error) { + const tunPath = "/dev/net/tun" + tunfd, err := unix.Open(tunPath, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + ifr := ifReq{ + Flags: unix.IFF_TUN | unix.IFF_NO_PI, + } + copy(ifr.Name[:], Name) + + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(tunfd), + uintptr(unix.TUNSETIFF), + uintptr(unsafe.Pointer(&ifr)), + ) + if errno != 0 { + return nil, errno + } + unix.SetNonblock(tunfd, true) + + return os.NewFile(uintptr(tunfd), tunPath), nil +} + +// configureTun configures our tun device. The function assigns an IP address, +// sets the link MTU, and may set the default gateway, after which the device +// is ready for use. +func configureTun(typ int) error { + cidrStr := "10.0.0.1/24" + if typ == asEnclave { + cidrStr = "10.0.0.2/24" + } + + link, err := tenus.NewLinkFrom(Name) + if err != nil { + return fmt.Errorf("failed to retrieve link: %w", err) + } + cidr, network, err := net.ParseCIDR(cidrStr) + if err != nil { + return fmt.Errorf("failed to parse CIDR: %w", err) + } + if err = link.SetLinkIp(cidr, network); err != nil { + return fmt.Errorf("failed to set link address: %w", err) + } + if err := link.SetLinkMTU(MTU); err != nil { + return fmt.Errorf("failed to set link MTU: %w", err) + } + // Set the enclave's default gateway to the proxy's IP address. + if typ == asEnclave { + gw := net.ParseIP("10.0.0.1") + if err := link.SetLinkDefaultGw(&gw); err != nil { + return fmt.Errorf("failed to set default gateway: %w", err) + } + } + if err := link.SetLinkUp(); err != nil { + return fmt.Errorf("failed to bring up link: %w", err) + } + + return nil +} diff --git a/vendor/golang.org/x/net/nettest/conntest.go b/vendor/golang.org/x/net/nettest/conntest.go new file mode 100644 index 0000000..4297d40 --- /dev/null +++ b/vendor/golang.org/x/net/nettest/conntest.go @@ -0,0 +1,467 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "bytes" + "encoding/binary" + "io" + "math/rand" + "net" + "runtime" + "sync" + "testing" + "time" +) + +// MakePipe creates a connection between two endpoints and returns the pair +// as c1 and c2, such that anything written to c1 is read by c2 and vice-versa. +// The stop function closes all resources, including c1, c2, and the underlying +// net.Listener (if there is one), and should not be nil. +type MakePipe func() (c1, c2 net.Conn, stop func(), err error) + +// TestConn tests that a net.Conn implementation properly satisfies the interface. +// The tests should not produce any false positives, but may experience +// false negatives. Thus, some issues may only be detected when the test is +// run multiple times. For maximal effectiveness, run the tests under the +// race detector. +func TestConn(t *testing.T, mp MakePipe) { + t.Run("BasicIO", func(t *testing.T) { timeoutWrapper(t, mp, testBasicIO) }) + t.Run("PingPong", func(t *testing.T) { timeoutWrapper(t, mp, testPingPong) }) + t.Run("RacyRead", func(t *testing.T) { timeoutWrapper(t, mp, testRacyRead) }) + t.Run("RacyWrite", func(t *testing.T) { timeoutWrapper(t, mp, testRacyWrite) }) + t.Run("ReadTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testReadTimeout) }) + t.Run("WriteTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testWriteTimeout) }) + t.Run("PastTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPastTimeout) }) + t.Run("PresentTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPresentTimeout) }) + t.Run("FutureTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testFutureTimeout) }) + t.Run("CloseTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testCloseTimeout) }) + t.Run("ConcurrentMethods", func(t *testing.T) { timeoutWrapper(t, mp, testConcurrentMethods) }) +} + +type connTester func(t *testing.T, c1, c2 net.Conn) + +func timeoutWrapper(t *testing.T, mp MakePipe, f connTester) { + t.Helper() + c1, c2, stop, err := mp() + if err != nil { + t.Fatalf("unable to make pipe: %v", err) + } + var once sync.Once + defer once.Do(func() { stop() }) + timer := time.AfterFunc(time.Minute, func() { + once.Do(func() { + t.Error("test timed out; terminating pipe") + stop() + }) + }) + defer timer.Stop() + f(t, c1, c2) +} + +// testBasicIO tests that the data sent on c1 is properly received on c2. +func testBasicIO(t *testing.T, c1, c2 net.Conn) { + want := make([]byte, 1<<20) + rand.New(rand.NewSource(0)).Read(want) + + dataCh := make(chan []byte) + go func() { + rd := bytes.NewReader(want) + if err := chunkedCopy(c1, rd); err != nil { + t.Errorf("unexpected c1.Write error: %v", err) + } + if err := c1.Close(); err != nil { + t.Errorf("unexpected c1.Close error: %v", err) + } + }() + + go func() { + wr := new(bytes.Buffer) + if err := chunkedCopy(wr, c2); err != nil { + t.Errorf("unexpected c2.Read error: %v", err) + } + if err := c2.Close(); err != nil { + t.Errorf("unexpected c2.Close error: %v", err) + } + dataCh <- wr.Bytes() + }() + + if got := <-dataCh; !bytes.Equal(got, want) { + t.Error("transmitted data differs") + } +} + +// testPingPong tests that the two endpoints can synchronously send data to +// each other in a typical request-response pattern. +func testPingPong(t *testing.T, c1, c2 net.Conn) { + var wg sync.WaitGroup + defer wg.Wait() + + pingPonger := func(c net.Conn) { + defer wg.Done() + buf := make([]byte, 8) + var prev uint64 + for { + if _, err := io.ReadFull(c, buf); err != nil { + if err == io.EOF { + break + } + t.Errorf("unexpected Read error: %v", err) + } + + v := binary.LittleEndian.Uint64(buf) + binary.LittleEndian.PutUint64(buf, v+1) + if prev != 0 && prev+2 != v { + t.Errorf("mismatching value: got %d, want %d", v, prev+2) + } + prev = v + if v == 1000 { + break + } + + if _, err := c.Write(buf); err != nil { + t.Errorf("unexpected Write error: %v", err) + break + } + } + if err := c.Close(); err != nil { + t.Errorf("unexpected Close error: %v", err) + } + } + + wg.Add(2) + go pingPonger(c1) + go pingPonger(c2) + + // Start off the chain reaction. + if _, err := c1.Write(make([]byte, 8)); err != nil { + t.Errorf("unexpected c1.Write error: %v", err) + } +} + +// testRacyRead tests that it is safe to mutate the input Read buffer +// immediately after cancelation has occurred. +func testRacyRead(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(c2, rand.New(rand.NewSource(0))) + + var wg sync.WaitGroup + defer wg.Wait() + + c1.SetReadDeadline(time.Now().Add(time.Millisecond)) + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + b1 := make([]byte, 1024) + b2 := make([]byte, 1024) + for j := 0; j < 100; j++ { + _, err := c1.Read(b1) + copy(b1, b2) // Mutate b1 to trigger potential race + if err != nil { + checkForTimeoutError(t, err) + c1.SetReadDeadline(time.Now().Add(time.Millisecond)) + } + } + }() + } +} + +// testRacyWrite tests that it is safe to mutate the input Write buffer +// immediately after cancelation has occurred. +func testRacyWrite(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(io.Discard, c2) + + var wg sync.WaitGroup + defer wg.Wait() + + c1.SetWriteDeadline(time.Now().Add(time.Millisecond)) + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + b1 := make([]byte, 1024) + b2 := make([]byte, 1024) + for j := 0; j < 100; j++ { + _, err := c1.Write(b1) + copy(b1, b2) // Mutate b1 to trigger potential race + if err != nil { + checkForTimeoutError(t, err) + c1.SetWriteDeadline(time.Now().Add(time.Millisecond)) + } + } + }() + } +} + +// testReadTimeout tests that Read timeouts do not affect Write. +func testReadTimeout(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(io.Discard, c2) + + c1.SetReadDeadline(aLongTimeAgo) + _, err := c1.Read(make([]byte, 1024)) + checkForTimeoutError(t, err) + if _, err := c1.Write(make([]byte, 1024)); err != nil { + t.Errorf("unexpected Write error: %v", err) + } +} + +// testWriteTimeout tests that Write timeouts do not affect Read. +func testWriteTimeout(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(c2, rand.New(rand.NewSource(0))) + + c1.SetWriteDeadline(aLongTimeAgo) + _, err := c1.Write(make([]byte, 1024)) + checkForTimeoutError(t, err) + if _, err := c1.Read(make([]byte, 1024)); err != nil { + t.Errorf("unexpected Read error: %v", err) + } +} + +// testPastTimeout tests that a deadline set in the past immediately times out +// Read and Write requests. +func testPastTimeout(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(c2, c2) + + testRoundtrip(t, c1) + + c1.SetDeadline(aLongTimeAgo) + n, err := c1.Write(make([]byte, 1024)) + if n != 0 { + t.Errorf("unexpected Write count: got %d, want 0", n) + } + checkForTimeoutError(t, err) + n, err = c1.Read(make([]byte, 1024)) + if n != 0 { + t.Errorf("unexpected Read count: got %d, want 0", n) + } + checkForTimeoutError(t, err) + + testRoundtrip(t, c1) +} + +// testPresentTimeout tests that a past deadline set while there are pending +// Read and Write operations immediately times out those operations. +func testPresentTimeout(t *testing.T, c1, c2 net.Conn) { + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(3) + + deadlineSet := make(chan bool, 1) + go func() { + defer wg.Done() + time.Sleep(100 * time.Millisecond) + deadlineSet <- true + c1.SetReadDeadline(aLongTimeAgo) + c1.SetWriteDeadline(aLongTimeAgo) + }() + go func() { + defer wg.Done() + n, err := c1.Read(make([]byte, 1024)) + if n != 0 { + t.Errorf("unexpected Read count: got %d, want 0", n) + } + checkForTimeoutError(t, err) + if len(deadlineSet) == 0 { + t.Error("Read timed out before deadline is set") + } + }() + go func() { + defer wg.Done() + var err error + for err == nil { + _, err = c1.Write(make([]byte, 1024)) + } + checkForTimeoutError(t, err) + if len(deadlineSet) == 0 { + t.Error("Write timed out before deadline is set") + } + }() +} + +// testFutureTimeout tests that a future deadline will eventually time out +// Read and Write operations. +func testFutureTimeout(t *testing.T, c1, c2 net.Conn) { + var wg sync.WaitGroup + wg.Add(2) + + c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) + go func() { + defer wg.Done() + _, err := c1.Read(make([]byte, 1024)) + checkForTimeoutError(t, err) + }() + go func() { + defer wg.Done() + var err error + for err == nil { + _, err = c1.Write(make([]byte, 1024)) + } + checkForTimeoutError(t, err) + }() + wg.Wait() + + go chunkedCopy(c2, c2) + resyncConn(t, c1) + testRoundtrip(t, c1) +} + +// testCloseTimeout tests that calling Close immediately times out pending +// Read and Write operations. +func testCloseTimeout(t *testing.T, c1, c2 net.Conn) { + go chunkedCopy(c2, c2) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(3) + + // Test for cancelation upon connection closure. + c1.SetDeadline(neverTimeout) + go func() { + defer wg.Done() + time.Sleep(100 * time.Millisecond) + c1.Close() + }() + go func() { + defer wg.Done() + var err error + buf := make([]byte, 1024) + for err == nil { + _, err = c1.Read(buf) + } + }() + go func() { + defer wg.Done() + var err error + buf := make([]byte, 1024) + for err == nil { + _, err = c1.Write(buf) + } + }() +} + +// testConcurrentMethods tests that the methods of net.Conn can safely +// be called concurrently. +func testConcurrentMethods(t *testing.T, c1, c2 net.Conn) { + if runtime.GOOS == "plan9" { + t.Skip("skipping on plan9; see https://golang.org/issue/20489") + } + go chunkedCopy(c2, c2) + + // The results of the calls may be nonsensical, but this should + // not trigger a race detector warning. + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(7) + go func() { + defer wg.Done() + c1.Read(make([]byte, 1024)) + }() + go func() { + defer wg.Done() + c1.Write(make([]byte, 1024)) + }() + go func() { + defer wg.Done() + c1.SetDeadline(time.Now().Add(10 * time.Millisecond)) + }() + go func() { + defer wg.Done() + c1.SetReadDeadline(aLongTimeAgo) + }() + go func() { + defer wg.Done() + c1.SetWriteDeadline(aLongTimeAgo) + }() + go func() { + defer wg.Done() + c1.LocalAddr() + }() + go func() { + defer wg.Done() + c1.RemoteAddr() + }() + } + wg.Wait() // At worst, the deadline is set 10ms into the future + + resyncConn(t, c1) + testRoundtrip(t, c1) +} + +// checkForTimeoutError checks that the error satisfies the Error interface +// and that Timeout returns true. +func checkForTimeoutError(t *testing.T, err error) { + t.Helper() + if nerr, ok := err.(net.Error); ok { + if !nerr.Timeout() { + if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && t.Name() == "TestTestConn/TCP/RacyRead" { + t.Logf("ignoring known failure mode on windows/arm64; see https://go.dev/issue/52893") + } else { + t.Errorf("got error: %v, want err.Timeout() = true", nerr) + } + } + } else { + t.Errorf("got %T: %v, want net.Error", err, err) + } +} + +// testRoundtrip writes something into c and reads it back. +// It assumes that everything written into c is echoed back to itself. +func testRoundtrip(t *testing.T, c net.Conn) { + t.Helper() + if err := c.SetDeadline(neverTimeout); err != nil { + t.Errorf("roundtrip SetDeadline error: %v", err) + } + + const s = "Hello, world!" + buf := []byte(s) + if _, err := c.Write(buf); err != nil { + t.Errorf("roundtrip Write error: %v", err) + } + if _, err := io.ReadFull(c, buf); err != nil { + t.Errorf("roundtrip Read error: %v", err) + } + if string(buf) != s { + t.Errorf("roundtrip data mismatch: got %q, want %q", buf, s) + } +} + +// resyncConn resynchronizes the connection into a sane state. +// It assumes that everything written into c is echoed back to itself. +// It assumes that 0xff is not currently on the wire or in the read buffer. +func resyncConn(t *testing.T, c net.Conn) { + t.Helper() + c.SetDeadline(neverTimeout) + errCh := make(chan error) + go func() { + _, err := c.Write([]byte{0xff}) + errCh <- err + }() + buf := make([]byte, 1024) + for { + n, err := c.Read(buf) + if n > 0 && bytes.IndexByte(buf[:n], 0xff) == n-1 { + break + } + if err != nil { + t.Errorf("unexpected Read error: %v", err) + break + } + } + if err := <-errCh; err != nil { + t.Errorf("unexpected Write error: %v", err) + } +} + +// chunkedCopy copies from r to w in fixed-width chunks to avoid +// causing a Write that exceeds the maximum packet size for packet-based +// connections like "unixpacket". +// We assume that the maximum packet size is at least 1024. +func chunkedCopy(w io.Writer, r io.Reader) error { + b := make([]byte, 1024) + _, err := io.CopyBuffer(struct{ io.Writer }{w}, struct{ io.Reader }{r}, b) + return err +} diff --git a/vendor/golang.org/x/net/nettest/nettest.go b/vendor/golang.org/x/net/nettest/nettest.go new file mode 100644 index 0000000..37e6dcb --- /dev/null +++ b/vendor/golang.org/x/net/nettest/nettest.go @@ -0,0 +1,344 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package nettest provides utilities for network testing. +package nettest + +import ( + "errors" + "fmt" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +var ( + stackOnce sync.Once + ipv4Enabled bool + canListenTCP4OnLoopback bool + ipv6Enabled bool + canListenTCP6OnLoopback bool + unStrmDgramEnabled bool + rawSocketSess bool + + aLongTimeAgo = time.Unix(233431200, 0) + neverTimeout = time.Time{} + + errNoAvailableInterface = errors.New("no available interface") + errNoAvailableAddress = errors.New("no available address") +) + +func probeStack() { + if _, err := RoutedInterface("ip4", net.FlagUp); err == nil { + ipv4Enabled = true + } + if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil { + ln.Close() + canListenTCP4OnLoopback = true + } + if _, err := RoutedInterface("ip6", net.FlagUp); err == nil { + ipv6Enabled = true + } + if ln, err := net.Listen("tcp6", "[::1]:0"); err == nil { + ln.Close() + canListenTCP6OnLoopback = true + } + rawSocketSess = supportsRawSocket() + switch runtime.GOOS { + case "aix": + // Unix network isn't properly working on AIX 7.2 with + // Technical Level < 2. + out, _ := exec.Command("oslevel", "-s").Output() + if len(out) >= len("7200-XX-ZZ-YYMM") { // AIX 7.2, Tech Level XX, Service Pack ZZ, date YYMM + ver := string(out[:4]) + tl, _ := strconv.Atoi(string(out[5:7])) + unStrmDgramEnabled = ver > "7200" || (ver == "7200" && tl >= 2) + } + default: + unStrmDgramEnabled = true + } +} + +func unixStrmDgramEnabled() bool { + stackOnce.Do(probeStack) + return unStrmDgramEnabled +} + +// SupportsIPv4 reports whether the platform supports IPv4 networking +// functionality. +func SupportsIPv4() bool { + stackOnce.Do(probeStack) + return ipv4Enabled +} + +// SupportsIPv6 reports whether the platform supports IPv6 networking +// functionality. +func SupportsIPv6() bool { + stackOnce.Do(probeStack) + return ipv6Enabled +} + +// SupportsRawSocket reports whether the current session is available +// to use raw sockets. +func SupportsRawSocket() bool { + stackOnce.Do(probeStack) + return rawSocketSess +} + +// TestableNetwork reports whether network is testable on the current +// platform configuration. +// +// See func Dial of the standard library for the supported networks. +func TestableNetwork(network string) bool { + ss := strings.Split(network, ":") + switch ss[0] { + case "ip+nopriv": + // This is an internal network name for testing on the + // package net of the standard library. + switch runtime.GOOS { + case "android", "fuchsia", "hurd", "ios", "js", "nacl", "plan9", "wasip1", "windows": + return false + } + case "ip", "ip4", "ip6": + switch runtime.GOOS { + case "fuchsia", "hurd", "js", "nacl", "plan9", "wasip1": + return false + default: + if os.Getuid() != 0 { + return false + } + } + case "unix", "unixgram": + switch runtime.GOOS { + case "android", "fuchsia", "hurd", "ios", "js", "nacl", "plan9", "wasip1", "windows": + return false + case "aix": + return unixStrmDgramEnabled() + } + case "unixpacket": + switch runtime.GOOS { + case "aix", "android", "fuchsia", "hurd", "darwin", "ios", "js", "nacl", "plan9", "wasip1", "windows", "zos": + return false + } + } + switch ss[0] { + case "tcp4", "udp4", "ip4": + return SupportsIPv4() + case "tcp6", "udp6", "ip6": + return SupportsIPv6() + } + return true +} + +// TestableAddress reports whether address of network is testable on +// the current platform configuration. +func TestableAddress(network, address string) bool { + switch ss := strings.Split(network, ":"); ss[0] { + case "unix", "unixgram", "unixpacket": + // Abstract unix domain sockets, a Linux-ism. + if address[0] == '@' && runtime.GOOS != "linux" { + return false + } + } + return true +} + +// NewLocalListener returns a listener which listens to a loopback IP +// address or local file system path. +// +// The provided network must be "tcp", "tcp4", "tcp6", "unix" or +// "unixpacket". +func NewLocalListener(network string) (net.Listener, error) { + stackOnce.Do(probeStack) + switch network { + case "tcp": + if canListenTCP4OnLoopback { + if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil { + return ln, nil + } + } + if canListenTCP6OnLoopback { + return net.Listen("tcp6", "[::1]:0") + } + case "tcp4": + if canListenTCP4OnLoopback { + return net.Listen("tcp4", "127.0.0.1:0") + } + case "tcp6": + if canListenTCP6OnLoopback { + return net.Listen("tcp6", "[::1]:0") + } + case "unix", "unixpacket": + path, err := LocalPath() + if err != nil { + return nil, err + } + return net.Listen(network, path) + } + return nil, fmt.Errorf("%s is not supported on %s/%s", network, runtime.GOOS, runtime.GOARCH) +} + +// NewLocalPacketListener returns a packet listener which listens to a +// loopback IP address or local file system path. +// +// The provided network must be "udp", "udp4", "udp6" or "unixgram". +func NewLocalPacketListener(network string) (net.PacketConn, error) { + stackOnce.Do(probeStack) + switch network { + case "udp": + if canListenTCP4OnLoopback { + if c, err := net.ListenPacket("udp4", "127.0.0.1:0"); err == nil { + return c, nil + } + } + if canListenTCP6OnLoopback { + return net.ListenPacket("udp6", "[::1]:0") + } + case "udp4": + if canListenTCP4OnLoopback { + return net.ListenPacket("udp4", "127.0.0.1:0") + } + case "udp6": + if canListenTCP6OnLoopback { + return net.ListenPacket("udp6", "[::1]:0") + } + case "unixgram": + path, err := LocalPath() + if err != nil { + return nil, err + } + return net.ListenPacket(network, path) + } + return nil, fmt.Errorf("%s is not supported on %s/%s", network, runtime.GOOS, runtime.GOARCH) +} + +// LocalPath returns a local path that can be used for Unix-domain +// protocol testing. +func LocalPath() (string, error) { + dir := "" + if runtime.GOOS == "darwin" { + dir = "/tmp" + } + f, err := os.CreateTemp(dir, "go-nettest") + if err != nil { + return "", err + } + path := f.Name() + f.Close() + os.Remove(path) + return path, nil +} + +// MulticastSource returns a unicast IP address on ifi when ifi is an +// IP multicast-capable network interface. +// +// The provided network must be "ip", "ip4" or "ip6". +func MulticastSource(network string, ifi *net.Interface) (net.IP, error) { + switch network { + case "ip", "ip4", "ip6": + default: + return nil, errNoAvailableAddress + } + if ifi == nil || ifi.Flags&net.FlagUp == 0 || ifi.Flags&net.FlagMulticast == 0 { + return nil, errNoAvailableAddress + } + ip, ok := hasRoutableIP(network, ifi) + if !ok { + return nil, errNoAvailableAddress + } + return ip, nil +} + +// LoopbackInterface returns an available logical network interface +// for loopback test. +func LoopbackInterface() (*net.Interface, error) { + ift, err := net.Interfaces() + if err != nil { + return nil, errNoAvailableInterface + } + for _, ifi := range ift { + if ifi.Flags&net.FlagLoopback != 0 && ifi.Flags&net.FlagUp != 0 { + return &ifi, nil + } + } + return nil, errNoAvailableInterface +} + +// RoutedInterface returns a network interface that can route IP +// traffic and satisfies flags. +// +// The provided network must be "ip", "ip4" or "ip6". +func RoutedInterface(network string, flags net.Flags) (*net.Interface, error) { + switch network { + case "ip", "ip4", "ip6": + default: + return nil, errNoAvailableInterface + } + ift, err := net.Interfaces() + if err != nil { + return nil, errNoAvailableInterface + } + for _, ifi := range ift { + if ifi.Flags&flags != flags { + continue + } + if _, ok := hasRoutableIP(network, &ifi); !ok { + continue + } + return &ifi, nil + } + return nil, errNoAvailableInterface +} + +func hasRoutableIP(network string, ifi *net.Interface) (net.IP, bool) { + ifat, err := ifi.Addrs() + if err != nil { + return nil, false + } + for _, ifa := range ifat { + switch ifa := ifa.(type) { + case *net.IPAddr: + if ip, ok := routableIP(network, ifa.IP); ok { + return ip, true + } + case *net.IPNet: + if ip, ok := routableIP(network, ifa.IP); ok { + return ip, true + } + } + } + return nil, false +} + +func routableIP(network string, ip net.IP) (net.IP, bool) { + if !ip.IsLoopback() && !ip.IsLinkLocalUnicast() && !ip.IsGlobalUnicast() { + return nil, false + } + switch network { + case "ip4": + if ip := ip.To4(); ip != nil { + return ip, true + } + case "ip6": + if ip.IsLoopback() { // addressing scope of the loopback address depends on each implementation + return nil, false + } + if ip := ip.To16(); ip != nil && ip.To4() == nil { + return ip, true + } + default: + if ip := ip.To4(); ip != nil { + return ip, true + } + if ip := ip.To16(); ip != nil { + return ip, true + } + } + return nil, false +} diff --git a/vendor/golang.org/x/net/nettest/nettest_stub.go b/vendor/golang.org/x/net/nettest/nettest_stub.go new file mode 100644 index 0000000..1725b6a --- /dev/null +++ b/vendor/golang.org/x/net/nettest/nettest_stub.go @@ -0,0 +1,11 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris && !windows && !zos + +package nettest + +func supportsRawSocket() bool { + return false +} diff --git a/vendor/golang.org/x/net/nettest/nettest_unix.go b/vendor/golang.org/x/net/nettest/nettest_unix.go new file mode 100644 index 0000000..9ba269d --- /dev/null +++ b/vendor/golang.org/x/net/nettest/nettest_unix.go @@ -0,0 +1,21 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos + +package nettest + +import "syscall" + +func supportsRawSocket() bool { + for _, af := range []int{syscall.AF_INET, syscall.AF_INET6} { + s, err := syscall.Socket(af, syscall.SOCK_RAW, 0) + if err != nil { + continue + } + syscall.Close(s) + return true + } + return false +} diff --git a/vendor/golang.org/x/net/nettest/nettest_windows.go b/vendor/golang.org/x/net/nettest/nettest_windows.go new file mode 100644 index 0000000..4939964 --- /dev/null +++ b/vendor/golang.org/x/net/nettest/nettest_windows.go @@ -0,0 +1,26 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import "syscall" + +func supportsRawSocket() bool { + // From http://msdn.microsoft.com/en-us/library/windows/desktop/ms740548.aspx: + // Note: To use a socket of type SOCK_RAW requires administrative privileges. + // Users running Winsock applications that use raw sockets must be a member of + // the Administrators group on the local computer, otherwise raw socket calls + // will fail with an error code of WSAEACCES. On Windows Vista and later, access + // for raw sockets is enforced at socket creation. In earlier versions of Windows, + // access for raw sockets is enforced during other socket operations. + for _, af := range []int{syscall.AF_INET, syscall.AF_INET6} { + s, err := syscall.Socket(af, syscall.SOCK_RAW, 0) + if err != nil { + continue + } + syscall.Closesocket(s) + return true + } + return false +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 50923a1..23eb9e9 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -174,6 +174,7 @@ go.opentelemetry.io/otel/trace/embedded # golang.org/x/net v0.33.0 ## explicit; go 1.18 golang.org/x/net/bpf +golang.org/x/net/nettest # golang.org/x/sync v0.8.0 ## explicit; go 1.18 golang.org/x/sync/errgroup