From f7cadbcdfbb84d367e27b5af32e89c138d72d9d7 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Wed, 24 Nov 2021 19:53:49 +0300 Subject: [PATCH] fix: handle duplicate peer updates Don't send peer updates to Wireguard when it's not required. Make all logging go via zap, convert Wireguard internal logger to use zap. Signed-off-by: Andrey Smirnov --- cmd/siderolink-agent/main.go | 8 ++- cmd/siderolink-agent/siderolink.go | 5 +- go.mod | 3 + go.sum | 7 +++ pkg/wireguard/wireguard.go | 98 ++++++++++++++++++++++++------ 5 files changed, 99 insertions(+), 22 deletions(-) diff --git a/cmd/siderolink-agent/main.go b/cmd/siderolink-agent/main.go index 377489f..4e06280 100644 --- a/cmd/siderolink-agent/main.go +++ b/cmd/siderolink-agent/main.go @@ -12,6 +12,7 @@ import ( "os" "os/signal" + "go.uber.org/zap" "golang.org/x/sync/errgroup" "google.golang.org/grpc" ) @@ -31,9 +32,14 @@ func main() { } func run(ctx context.Context) error { + logger, err := zap.NewDevelopment() + if err != nil { + return fmt.Errorf("error creating logger") + } + eg, ctx := errgroup.WithContext(ctx) - if err := sideroLink(ctx, eg); err != nil { + if err := sideroLink(ctx, eg, logger); err != nil { return fmt.Errorf("SideroLink: %w", err) } diff --git a/cmd/siderolink-agent/siderolink.go b/cmd/siderolink-agent/siderolink.go index 0597dea..e7ed8a6 100644 --- a/cmd/siderolink-agent/siderolink.go +++ b/cmd/siderolink-agent/siderolink.go @@ -9,6 +9,7 @@ import ( "fmt" "net" + "go.uber.org/zap" "golang.org/x/sync/errgroup" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -24,7 +25,7 @@ var sideroLinkFlags struct { apiEndpoint string } -func sideroLink(ctx context.Context, eg *errgroup.Group) error { +func sideroLink(ctx context.Context, eg *errgroup.Group, logger *zap.Logger) error { lis, err := net.Listen("tcp", sideroLinkFlags.apiEndpoint) if err != nil { return fmt.Errorf("error listening for gRPC API: %w", err) @@ -61,7 +62,7 @@ func sideroLink(ctx context.Context, eg *errgroup.Group) error { pb.RegisterProvisionServiceServer(s, srv) eg.Go(func() error { - return wgDevice.Run(ctx, srv) + return wgDevice.Run(ctx, logger, srv) }) eg.Go(func() error { diff --git a/go.mod b/go.mod index 6c2a11f..8b746fe 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 github.com/stretchr/testify v1.7.0 github.com/talos-systems/talos/pkg/machinery v0.14.0-alpha.1.0.20211118180932-1ffa8e048008 + go.uber.org/zap v1.18.1 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.zx2c4.com/wireguard v0.0.0-20211109020618-685490f568cf golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211109202428-0073765f69ba @@ -31,6 +32,8 @@ require ( github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.7.0 // indirect go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa // indirect diff --git a/go.sum b/go.sum index a4f3158..b447a64 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,7 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -215,12 +216,16 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.18.1 h1:CSUJ2mjFszzEWt4CdKISEuChVIXGBn3lAPwkRGyVrc4= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go4.org/intern v0.0.0-20211027215823-ae77deb06f29 h1:UXLjNohABv4S58tHmeuIZDO6e3mHpW2Dx33gaNt03LE= go4.org/intern v0.0.0-20211027215823-ae77deb06f29/go.mod h1:cS2ma+47FKrLPdXFpr7CuxiTW3eyJbWew4qx0qtQWDA= @@ -243,6 +248,7 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -455,6 +461,7 @@ gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/wireguard/wireguard.go b/pkg/wireguard/wireguard.go index 2cdc471..612ae26 100644 --- a/pkg/wireguard/wireguard.go +++ b/pkg/wireguard/wireguard.go @@ -12,6 +12,7 @@ import ( "os" "github.com/jsimonetti/rtnetlink/rtnl" + "go.uber.org/zap" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" @@ -56,7 +57,7 @@ func NewDevice(address netaddr.IPPrefix, privateKey wgtypes.Key, listenPort uint } // Run the device. -func (dev *Device) Run(ctx context.Context, peers PeerSource) error { +func (dev *Device) Run(ctx context.Context, logger *zap.Logger, peers PeerSource) error { client, err := wgctrl.New() if err != nil { return fmt.Errorf("error initializing Wireguard client: %w", err) @@ -71,10 +72,10 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error { defer rtnlClient.Close() //nolint:errcheck - logger := device.NewLogger( - device.LogLevelVerbose, - fmt.Sprintf("(%s) ", interfaceName), - ) + wgLogger := &device.Logger{ + Verbosef: logger.Sugar().Debugf, + Errorf: logger.Sugar().Errorf, + } uapi, err := ipc.UAPIListen(interfaceName, dev.fileUAPI) if err != nil { @@ -83,7 +84,7 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error { defer uapi.Close() //nolint:errcheck - device := device.NewDevice(dev.tun, conn.NewDefaultBind(), logger) + device := device.NewDevice(dev.tun, conn.NewDefaultBind(), wgLogger) defer device.Close() @@ -124,6 +125,8 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error { return fmt.Errorf("error bringing link up: %w", err) } + logger.Info("wireguard device set up", zap.String("interface", interfaceName), zap.Stringer("address", dev.address)) + for { select { case <-ctx.Done(): @@ -133,24 +136,81 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error { case <-device.Wait(): return nil case peerEvent := <-peers.EventCh(): - cfg := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - PublicKey: peerEvent.PubKey, - Remove: peerEvent.Remove, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - *netaddr.IPPrefixFrom(peerEvent.Address, peerEvent.Address.BitLen()).IPNet(), - }, - }, - }, + if err := dev.handlePeerEvent(client, logger, peerEvent); err != nil { + return err + } + } + } +} + +func (dev *Device) checkDuplicateUpdate(client *wgctrl.Client, logger *zap.Logger, peerEvent PeerEvent) (bool, error) { + oldCfg, err := client.Device(interfaceName) + if err != nil { + return false, fmt.Errorf("error retrieving Wireguard configuration: %w", err) + } + + // check if this update can be skipped + pubKey := peerEvent.PubKey.String() + + for _, oldPeer := range oldCfg.Peers { + if oldPeer.PublicKey.String() == pubKey { + if len(oldPeer.AllowedIPs) != 1 { + break } - if err = client.ConfigureDevice(interfaceName, cfg); err != nil { - return fmt.Errorf("error configuring Wireguard peers: %w", err) + if prefix, ok := netaddr.FromStdIPNet(&oldPeer.AllowedIPs[0]); ok { + if prefix.IP() == peerEvent.Address { + // skip the update + logger.Info("skipping peer update", zap.String("public_key", pubKey)) + + return true, nil + } } + + break } } + + return false, nil +} + +func (dev *Device) handlePeerEvent(client *wgctrl.Client, logger *zap.Logger, peerEvent PeerEvent) error { + if !peerEvent.Remove { + skipEvent, err := dev.checkDuplicateUpdate(client, logger, peerEvent) + if err != nil { + return err + } + + if skipEvent { + return nil + } + } + + cfg := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerEvent.PubKey, + Remove: peerEvent.Remove, + }, + }, + } + + if !peerEvent.Remove { + cfg.Peers[0].ReplaceAllowedIPs = true + cfg.Peers[0].AllowedIPs = []net.IPNet{ + *netaddr.IPPrefixFrom(peerEvent.Address, peerEvent.Address.BitLen()).IPNet(), + } + + logger.Info("updating peer", zap.Stringer("public_key", peerEvent.PubKey), zap.Stringer("address", peerEvent.Address)) + } else { + logger.Info("removing peer", zap.Stringer("public_key", peerEvent.PubKey)) + } + + if err := client.ConfigureDevice(interfaceName, cfg); err != nil { + return fmt.Errorf("error configuring Wireguard peers: %w", err) + } + + return nil } // Close the device.