diff --git a/charts/raven-agent/templates/daemonset.yaml b/charts/raven-agent/templates/daemonset.yaml index 0b4653b..e0b5f7a 100644 --- a/charts/raven-agent/templates/daemonset.yaml +++ b/charts/raven-agent/templates/daemonset.yaml @@ -56,6 +56,8 @@ spec: - --vpn-bind-port={{.Values.vpn.tunnelAddr}} - --keep-alive-interval={{.Values.vpn.keepAliveInterval}} - --keep-alive-timeout={{.Values.vpn.keepAliveTimeout}} + - --sync-raven-rules={{.Values.sync.syncRule}} + - --sync-raven-rules-period={{.Values.sync.syncPeriod}} - --proxy-metric-bind-addr={{.Values.proxy.metricsBindAddr}} - --proxy-internal-secure-addr={{.Values.proxy.internalSecureAddr}} - --proxy-internal-insecure-addr={{.Values.proxy.internalInsecureAddr}} diff --git a/charts/raven-agent/values.yaml b/charts/raven-agent/values.yaml index ddd5287..e794cbe 100644 --- a/charts/raven-agent/values.yaml +++ b/charts/raven-agent/values.yaml @@ -59,6 +59,9 @@ containerEnv: secretKeyRef: key: vpn-connection-psk name: raven-agent-secret +sync: + syncRule: true + syncPeriod: 30m vpn: driver: libreswan @@ -86,4 +89,4 @@ proxy: metricsBindAddr: ":10266" rollingUpdate: - maxUnavailable: 5% \ No newline at end of file + maxUnavailable: 20% \ No newline at end of file diff --git a/cmd/agent/app/config/config.go b/cmd/agent/app/config/config.go index b9b4a3b..7667418 100644 --- a/cmd/agent/app/config/config.go +++ b/cmd/agent/app/config/config.go @@ -17,14 +17,18 @@ package config import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/rest" "sigs.k8s.io/controller-runtime/pkg/manager" ) // Config is the main context object for raven agent type Config struct { - NodeName string - NodeIP string + NodeName string + NodeIP string + SyncRules bool + SyncPeriod metav1.Duration + MetricsBindAddress string HealthProbeAddr string diff --git a/cmd/agent/app/options/options.go b/cmd/agent/app/options/options.go index a3c1f78..016dbd0 100644 --- a/cmd/agent/app/options/options.go +++ b/cmd/agent/app/options/options.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/spf13/pflag" v1 "k8s.io/api/core/v1" @@ -50,6 +51,8 @@ type AgentOptions struct { Kubeconfig string MetricsBindAddress string HealthProbeAddr string + SyncRules bool + SyncPeriod metav1.Duration } type TunnelOptions struct { @@ -91,6 +94,12 @@ func (o *AgentOptions) Validate() error { } } } + if o.SyncPeriod.Duration < time.Minute { + o.SyncPeriod.Duration = time.Minute + } + if o.SyncPeriod.Duration > 24*time.Hour { + o.SyncPeriod.Duration = 24 * time.Hour + } return nil } @@ -103,6 +112,8 @@ func (o *AgentOptions) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&o.RouteDriver, "route-driver", o.RouteDriver, `The Route driver name. (default "vxlan")`) fs.StringVar(&o.MetricsBindAddress, "metric-bind-addr", o.MetricsBindAddress, `Binding address of tunnel metrics. (default ":10265")`) fs.StringVar(&o.HealthProbeAddr, "health-probe-addr", o.HealthProbeAddr, `The address the healthz/readyz endpoint binds to.. (default ":10275")`) + fs.BoolVar(&o.SyncRules, "sync-raven-rules", true, "Whether to synchronize raven rules regularly") + fs.DurationVar(&o.SyncPeriod.Duration, "sync-raven-rules-period", 10*time.Minute, "The period for reconciling routes created for nodes by cloud provider. The minimum value is 1 minute and the maximum value is 24 hour") fs.StringVar(&o.VPNPort, "vpn-bind-port", o.VPNPort, `Binding port of vpn. (default ":4500")`) fs.BoolVar(&o.NATTraversal, "nat-traversal", o.NATTraversal, `Enable NAT Traversal or not. (default "false")`) @@ -141,8 +152,10 @@ func (o *AgentOptions) Config() (*config.Config, error) { } cfg = restclient.AddUserAgent(cfg, "raven-agent-ds") c := &config.Config{ - NodeName: o.NodeName, - NodeIP: o.NodeIP, + NodeName: o.NodeName, + NodeIP: o.NodeIP, + SyncRules: o.SyncRules, + SyncPeriod: o.SyncPeriod, } c.KubeConfig = cfg c.MetricsBindAddress = resolveAddress(c.MetricsBindAddress, resolveLocalHost(), strconv.Itoa(DefaultTunnelMetricsPort)) diff --git a/cmd/agent/app/start.go b/cmd/agent/app/start.go index e542bab..fe23038 100644 --- a/cmd/agent/app/start.go +++ b/cmd/agent/app/start.go @@ -19,15 +19,17 @@ package app import ( "context" "fmt" + "sync" + "time" "github.com/lorenzosaino/go-sysctl" + "github.com/spf13/cobra" "k8s.io/klog/v2" "github.com/openyurtio/raven/cmd/agent/app/config" "github.com/openyurtio/raven/cmd/agent/app/options" ravenengine "github.com/openyurtio/raven/pkg/engine" "github.com/openyurtio/raven/pkg/features" - "github.com/spf13/cobra" ) // NewRavenAgentCommand creates a new raven agent command @@ -70,6 +72,15 @@ func Run(ctx context.Context, cfg *config.CompletedConfig) error { } klog.Info("engine successfully start") engine.Start() + var wg sync.WaitGroup + wg.Add(1) + go func() { + <-ctx.Done() + time.Sleep(time.Second) + engine.Cleanup() + wg.Done() + }() + wg.Wait() return nil } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index e1b2733..4a37a8e 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -32,7 +32,7 @@ var GitCommit string func main() { klog.InitFlags(nil) defer klog.Flush() - rand.Seed(time.Now().UnixNano()) + rand.NewSource(time.Now().UnixNano()) klog.Infof("component: %s, git commit: %s\n", "raven-agent-ds", GitCommit) cmd := app.NewRavenAgentCommand(server.SetupSignalContext()) cmd.Flags().AddGoFlagSet(flag.CommandLine) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index c06a8a9..70acbfe 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -4,6 +4,7 @@ import ( "context" "time" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/util/workqueue" @@ -22,13 +23,16 @@ import ( ) type Engine struct { - nodeName string - nodeIP string - context context.Context - manager manager.Manager - client client.Client - option *Option - queue workqueue.RateLimitingInterface + nodeName string + nodeIP string + syncRules bool + syncPeriod metav1.Duration + + context context.Context + manager manager.Manager + client client.Client + option *Option + queue workqueue.RateLimitingInterface tunnel *TunnelEngine proxy *ProxyEngine @@ -36,12 +40,14 @@ type Engine struct { func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { engine := &Engine{ - nodeName: cfg.NodeName, - nodeIP: cfg.NodeIP, - manager: cfg.Manager, - context: ctx, - option: NewEngineOption(), - queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "raven"), + nodeName: cfg.NodeName, + nodeIP: cfg.NodeIP, + syncRules: cfg.SyncRules, + syncPeriod: cfg.SyncPeriod, + manager: cfg.Manager, + context: ctx, + option: NewEngineOption(), + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "raven"), } err := ctrl.NewControllerManagedBy(engine.manager). For(&v1beta1.Gateway{}, builder.WithPredicates(predicate.Funcs{ @@ -53,7 +59,7 @@ func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { return reconcile.Result{}, nil })) if err != nil { - klog.Errorf(utils.FormatRavenEngine("fail to new controller with manager, error %s", err.Error())) + klog.Errorf("fail to new controller with manager, error %s", err.Error()) return engine, err } engine.client = engine.manager.GetClient() @@ -66,7 +72,7 @@ func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { } err = engine.tunnel.InitDriver() if err != nil { - klog.Errorf(utils.FormatRavenEngine("fail to init tunnel driver, error %s", err.Error())) + klog.Errorf("fail to init tunnel driver, error %s", err.Error()) return engine, err } @@ -90,9 +96,12 @@ func (e *Engine) Start() { klog.ErrorS(err, "failed to start engine controller") } }() + go wait.Until(e.worker, time.Second, e.context.Done()) - <-e.context.Done() - e.cleanup() + + if e.syncRules { + go wait.Until(e.regularSync, e.syncPeriod.Duration, e.context.Done()) + } } func (e *Engine) worker() { @@ -110,19 +119,29 @@ func (e *Engine) processNextWorkItem() bool { return false } defer e.queue.Done(gw) - e.findLocalGateway() - err := e.tunnel.Handler() + err := e.sync() if err != nil { e.handleEventErr(err, gw) } - e.option.SetTunnelStatus(e.tunnel.Status()) + return true +} - err = e.proxy.Handler() +func (e *Engine) sync() error { + e.findLocalGateway() + err := e.proxy.Handler() if err != nil { - e.handleEventErr(err, gw) + return err + } + err = e.tunnel.Handler() + if err != nil { + return err } + e.option.SetTunnelStatus(e.tunnel.Status()) + return nil +} - return true +func (e *Engine) regularSync() { + e.queue.Add(&v1beta1.Gateway{ObjectMeta: metav1.ObjectMeta{Name: "gw-sync"}}) } func (e *Engine) findLocalGateway() { @@ -144,12 +163,9 @@ func (e *Engine) findLocalGateway() { } } -func (e *Engine) cleanup() { +func (e *Engine) Cleanup() { if e.option.GetTunnelStatus() { - err := e.tunnel.CleanupDriver() - if err != nil { - klog.Errorf(utils.FormatRavenEngine("failed to cleanup tunnel driver, error %s", err.Error())) - } + e.tunnel.CleanupDriver() } if e.option.GetProxyStatus() { e.proxy.stop() @@ -163,18 +179,18 @@ func (e *Engine) handleEventErr(err error, gw *v1beta1.Gateway) { } if e.queue.NumRequeues(gw) < utils.MaxRetries { - klog.Info(utils.FormatRavenEngine("error syncing event %s: %s", gw.GetName(), err.Error())) + klog.Infof("error syncing event %s: %s", gw.GetName(), err.Error()) e.queue.AddRateLimited(gw) return } - klog.Info(utils.FormatRavenEngine("dropping event %s out of the queue: %s", gw.GetName(), err.Error())) + klog.Infof("dropping event %s out of the queue: %s", gw.GetName(), err.Error()) e.queue.Forget(gw) } func (e *Engine) addGateway(evt event.CreateEvent) bool { gw, ok := evt.Object.(*v1beta1.Gateway) if ok { - klog.InfoS(utils.FormatRavenEngine("adding gateway %s", gw.GetName())) + klog.Infof("adding gateway %s", gw.GetName()) e.queue.Add(gw.DeepCopy()) } return ok @@ -187,10 +203,8 @@ func (e *Engine) updateGateway(evt event.UpdateEvent) bool { if ok1 && ok2 { if oldGw.ResourceVersion != newGw.ResourceVersion { update = true - klog.InfoS(utils.FormatRavenEngine("updating gateway, %s", newGw.GetName())) + klog.Infof("updating gateway, %s", newGw.GetName()) e.queue.Add(newGw.DeepCopy()) - } else { - klog.InfoS(utils.FormatRavenEngine("skip handle update gateway"), klog.KObj(newGw)) } } return update @@ -199,7 +213,7 @@ func (e *Engine) updateGateway(evt event.UpdateEvent) bool { func (e *Engine) deleteGateway(evt event.DeleteEvent) bool { gw, ok := evt.Object.(*v1beta1.Gateway) if ok { - klog.InfoS(utils.FormatRavenEngine("deleting gateway, %s", gw.GetName())) + klog.Infof("deleting gateway, %s", gw.GetName()) e.queue.Add(gw.DeepCopy()) } return ok diff --git a/pkg/engine/proxy.go b/pkg/engine/proxy.go index 87058b3..8f3f016 100644 --- a/pkg/engine/proxy.go +++ b/pkg/engine/proxy.go @@ -79,7 +79,7 @@ func (p *ProxyEngine) Handler() error { srcAddr := getSrcAddressForProxyServer(p.client, p.nodeName) err = p.startProxyServer() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy server, error %s", err.Error())) + klog.Errorf("failed to start proxy server, error %s", err.Error()) return err } p.serverLocalEndpoints = srcAddr @@ -93,7 +93,7 @@ func (p *ProxyEngine) Handler() error { time.Sleep(2 * time.Second) err = p.startProxyServer() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy server, error %s", err.Error())) + klog.Errorf("failed to start proxy server, error %s", err.Error()) return err } p.serverLocalEndpoints = srcAddr @@ -106,7 +106,7 @@ func (p *ProxyEngine) Handler() error { case StartType: err = p.startProxyClient() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy client, error %s", err.Error())) + klog.Errorf("failed to start proxy client, error %s", err.Error()) return err } case StopType: @@ -114,7 +114,7 @@ func (p *ProxyEngine) Handler() error { case RestartType: dstAddr := getDestAddressForProxyClient(p.client, p.localGateway) if len(dstAddr) < 1 { - klog.Infoln(utils.FormatProxyClient("dest address is empty, will not connected it")) + klog.Infoln("dest address is empty, will not connected it") return nil } if strings.Join(p.clientRemoteEndpoints, ",") != strings.Join(dstAddr, ",") { @@ -122,7 +122,7 @@ func (p *ProxyEngine) Handler() error { time.Sleep(2 * time.Second) err = p.startProxyClient() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy client, error %s", err.Error())) + klog.Errorf("failed to start proxy client, error %s", err.Error()) return err } } @@ -133,7 +133,7 @@ func (p *ProxyEngine) Handler() error { } func (p *ProxyEngine) startProxyServer() error { - klog.Infoln(utils.FormatProxyServer("start raven l7 proxy server")) + klog.Infoln("start raven l7 proxy server") if p.localGateway == nil { return fmt.Errorf("unknown gateway for node %s, can not start proxy server", p.nodeName) } @@ -164,7 +164,7 @@ func (p *ProxyEngine) startProxyServer() error { } func (p *ProxyEngine) stopProxyServer() { - klog.Infoln(utils.FormatProxyServer("Stop raven l7 proxy server")) + klog.Infoln("Stop raven l7 proxy server") cancel := p.proxyCtx.GetServerCancelFunc() cancel() p.proxyOption.SetServerStatus(false) @@ -172,11 +172,11 @@ func (p *ProxyEngine) stopProxyServer() { } func (p *ProxyEngine) startProxyClient() error { - klog.Infoln(utils.FormatProxyClient("start raven l7 proxy client")) + klog.Infoln("start raven l7 proxy client") var err error dstAddr := getDestAddressForProxyClient(p.client, p.localGateway) if len(dstAddr) < 1 { - klog.Infoln(utils.FormatProxyClient("dest address is empty, will not connected it")) + klog.Infoln("dest address is empty, will not connected it") return nil } p.clientRemoteEndpoints = dstAddr @@ -195,13 +195,14 @@ func (p *ProxyEngine) startProxyClient() error { err = pc.Start(ctx) if err != nil { klog.Errorf("failed to start proxy client, error %s", err.Error()) + return err } p.proxyOption.SetClientStatus(true) return nil } func (p *ProxyEngine) stopProxyClient() { - klog.Infoln(utils.FormatProxyClient("stop raven l7 proxy client")) + klog.Infoln("stop raven l7 proxy client") cancel := p.proxyCtx.GetClientCancelFunc() cancel() p.proxyOption.SetClientStatus(false) diff --git a/pkg/engine/tunnel.go b/pkg/engine/tunnel.go index 6bae19c..3255700 100644 --- a/pkg/engine/tunnel.go +++ b/pkg/engine/tunnel.go @@ -20,13 +20,14 @@ import ( "context" "fmt" "net" - "reflect" "strconv" + "time" "github.com/EvilSuperstars/go-cidrman" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/util/retry" "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" @@ -51,9 +52,8 @@ type TunnelEngine struct { routeDriver routedriver.Driver vpnDriver vpndriver.Driver - nodeInfos map[types.NodeName]*v1beta1.NodeInfo - network *types.Network - lastSeenNetwork *types.Network + nodeInfos map[types.NodeName]*v1beta1.NodeInfo + network *types.Network } func (c *TunnelEngine) InitDriver() error { @@ -74,20 +74,24 @@ func (c *TunnelEngine) InitDriver() error { if err != nil { return fmt.Errorf("fail to initialize vpn driver: %s, %s", c.config.Tunnel.VPNDriver, err) } - klog.Info(utils.FormatTunnel("route driver %s and vpn driver %s are initialized", c.config.Tunnel.RouteDriver, c.config.Tunnel.VPNDriver)) + klog.Infof("route driver %s and vpn driver %s are initialized", c.config.Tunnel.RouteDriver, c.config.Tunnel.VPNDriver) return nil } -func (c *TunnelEngine) CleanupDriver() error { - err := c.routeDriver.Cleanup() - if err != nil { - return fmt.Errorf("fail to cleanup route driver: %s", err.Error()) - } - err = c.vpnDriver.Cleanup() - if err != nil { - return fmt.Errorf("fail to cleanup vpn driver: %s", err.Error()) - } - return nil +func (c *TunnelEngine) CleanupDriver() { + _ = wait.PollImmediate(time.Second, 5*time.Second, func() (done bool, err error) { + err = c.vpnDriver.Cleanup() + if err != nil { + klog.Errorf("fail to cleanup vpn driver: %s", err.Error()) + return false, nil + } + err = c.routeDriver.Cleanup() + if err != nil { + klog.Errorf("fail to cleanup route driver: %s", err.Error()) + return false, nil + } + return true, nil + }) } func (c *TunnelEngine) Status() bool { @@ -105,7 +109,7 @@ func (c *TunnelEngine) Status() bool { func (c *TunnelEngine) Handler() error { if c.config.Tunnel.NATTraversal { if err := c.checkNatCapability(); err != nil { - klog.Errorf(utils.FormatTunnel("fail to check the capability of NAT, error %s", err.Error())) + klog.Errorf("fail to check the capability of NAT, error %s", err.Error()) return err } } @@ -154,26 +158,18 @@ func (c *TunnelEngine) Handler() error { } c.syncGateway(gw) } - if reflect.DeepEqual(c.network, c.lastSeenNetwork) { - klog.Info("network not changed, skip to process") - return nil - } nw := c.network.Copy() klog.InfoS("applying network", "localEndpoint", nw.LocalEndpoint, "remoteEndpoint", nw.RemoteEndpoints) err = c.vpnDriver.Apply(nw, c.routeDriver.MTU) if err != nil { - klog.ErrorS(err, "error apply vpn driver") + klog.Errorf("error apply vpn driver, error %s", err.Error()) return err } err = c.routeDriver.Apply(nw, c.vpnDriver.MTU) if err != nil { - klog.ErrorS(err, "error apply route driver") + klog.Errorf("error apply route driver, error %s", err.Error()) return err } - - // Only update lastSeenNetwork when all operations succeeded. - c.lastSeenNetwork = c.network - return nil } diff --git a/pkg/networkengine/util/ipset/ipset.go b/pkg/networkengine/util/ipset/ipset.go index 66e34e6..ee844eb 100644 --- a/pkg/networkengine/util/ipset/ipset.go +++ b/pkg/networkengine/util/ipset/ipset.go @@ -33,14 +33,26 @@ type IPSetInterface interface { Del(entry *netlink.IPSetEntry) error Flush() error Destroy() error + Key(entry *netlink.IPSetEntry) string } +var DefaultKeyFunc = EntryKey + type ipSetWrapper struct { setName string + setType string + keyFunc func(setEntry *netlink.IPSetEntry) string +} + +type IpsetWrapperOption struct { + KeyFunc func(setEntry *netlink.IPSetEntry) string } -func New(setName string) (IPSetInterface, error) { - err := netlink.IpsetCreate(setName, "hash:net", netlink.IpsetCreateOptions{ +func New(setName, setTypeName string, options IpsetWrapperOption) (IPSetInterface, error) { + if options.KeyFunc == nil { + options.KeyFunc = DefaultKeyFunc + } + err := netlink.IpsetCreate(setName, setTypeName, netlink.IpsetCreateOptions{ Replace: true, }) if err != nil { @@ -50,7 +62,7 @@ func New(setName string) (IPSetInterface, error) { if klog.V(5).Enabled() { klog.V(5).InfoS("netlink.IpsetCreate succeeded", "setName", setName) } - return &ipSetWrapper{setName}, nil + return &ipSetWrapper{setName, setTypeName, options.KeyFunc}, nil } func (i *ipSetWrapper) List() (*netlink.IPSetResult, error) { @@ -72,11 +84,11 @@ func (i *ipSetWrapper) Name() string { func (i *ipSetWrapper) Add(entry *netlink.IPSetEntry) (err error) { err = netlink.IpsetAdd(i.Name(), entry) if err != nil { - klog.ErrorS(err, "error on netlink.IpsetAdd", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.ErrorS(err, "error on netlink.IpsetAdd", "setName", i.Name(), "entry", i.Key(entry)) return } if klog.V(5).Enabled() { - klog.V(5).InfoS("netlink.IpsetAdd succeeded", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.V(5).InfoS("netlink.IpsetAdd succeeded", "setName", i.Name(), "entry", i.Key(entry)) } return } @@ -84,11 +96,11 @@ func (i *ipSetWrapper) Add(entry *netlink.IPSetEntry) (err error) { func (i *ipSetWrapper) Del(entry *netlink.IPSetEntry) (err error) { err = netlink.IpsetDel(i.Name(), entry) if err != nil { - klog.ErrorS(err, "error on netlink.IpsetDel", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.ErrorS(err, "error on netlink.IpsetDel", "setName", i.Name(), "entry", i.Key) return } if klog.V(5).Enabled() { - klog.V(5).InfoS("netlink.IpsetDel succeeded", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.V(5).InfoS("netlink.IpsetDel succeeded", "setName", i.Name(), "entry", i.Key(entry)) } return } @@ -117,6 +129,10 @@ func (i *ipSetWrapper) Destroy() (err error) { return } -func SetEntryKey(setEntry *netlink.IPSetEntry) string { +func (i *ipSetWrapper) Key(entry *netlink.IPSetEntry) string { + return i.keyFunc(entry) +} + +func EntryKey(setEntry *netlink.IPSetEntry) string { return fmt.Sprintf("%s/%d", setEntry.IP.String(), setEntry.CIDR) } diff --git a/pkg/networkengine/util/netlink/netlink.go b/pkg/networkengine/util/netlink/netlink.go index 2c05366..a503f81 100644 --- a/pkg/networkengine/util/netlink/netlink.go +++ b/pkg/networkengine/util/netlink/netlink.go @@ -41,6 +41,7 @@ var ( RuleDel = ruleDel XfrmPolicyFlush = xfrmPolicyFlush + XfrmStateFlush = xfrmStateFlush NeighAdd = neighAdd NeighReplace = neighReplace @@ -127,6 +128,16 @@ func xfrmPolicyFlush() (err error) { return nil } +func xfrmStateFlush() (err error) { + err = netlink.XfrmStateFlush(0) + if err != nil { + klog.ErrorS(err, "error on netlink.XfrmStateFlush") + return + } + klog.V(5).InfoS("netlink.XfrmStateFlush succeeded") + return nil +} + func ruleListFiltered(family int, filter *netlink.Rule, filterMask uint64) (rules []netlink.Rule, err error) { rules, err = netlink.RuleListFiltered(family, filter, filterMask) if err != nil { diff --git a/pkg/networkengine/util/utils.go b/pkg/networkengine/util/utils.go index 23b0add..1b9c834 100644 --- a/pkg/networkengine/util/utils.go +++ b/pkg/networkengine/util/utils.go @@ -21,7 +21,6 @@ package networkutil import ( "fmt" - "net" "syscall" "github.com/vdobler/ht/errorlist" @@ -32,11 +31,6 @@ import ( netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" ) -var ( - AllZeroMAC = net.HardwareAddr{0, 0, 0, 0, 0, 0} - AllZeroAddress = "0.0.0.0/0" -) - func NewRavenRule(rulePriority int, routeTableID int) *netlink.Rule { rule := netlink.NewRule() rule.Priority = rulePriority @@ -94,7 +88,7 @@ func ListIPSetOnNode(set ipsetutil.IPSetInterface) (map[string]*netlink.IPSetEnt } ro := make(map[string]*netlink.IPSetEntry) for i := range info.Entries { - ro[ipsetutil.SetEntryKey(&info.Entries[i])] = &info.Entries[i] + ro[set.Key(&info.Entries[i])] = &info.Entries[i] } return ro, nil } @@ -114,7 +108,11 @@ func ApplyRules(current, desired map[string]*netlink.Rule) (err error) { } } // add expect ip rules - for _, v := range desired { + for k, v := range desired { + _, ok := current[k] + if ok { + continue + } klog.InfoS("adding rule", "src", v.Src, "lookup", v.Table) err = netlinkutil.RuleAdd(v) errList = errList.Append(err) diff --git a/pkg/networkengine/vpndriver/ipset/ipset.go b/pkg/networkengine/vpndriver/ipset/ipset.go new file mode 100644 index 0000000..f3fdfd2 --- /dev/null +++ b/pkg/networkengine/vpndriver/ipset/ipset.go @@ -0,0 +1,129 @@ +/* +Copyright 2023 The OpenYurt Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ipset + +import ( + "fmt" + "net" + + "github.com/EvilSuperstars/go-cidrman" + "github.com/vdobler/ht/errorlist" + "github.com/vishvananda/netlink" + "k8s.io/klog/v2" + + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" + "github.com/openyurtio/raven/pkg/types" +) + +const ( + RavenSkipNatSet = "raven-skip-nat-set" + RavenSkipNatSetType = "hash:net,net" +) + +var KeyFunc = func(entry *netlink.IPSetEntry) string { + return fmt.Sprintf("%s/%d-%s/%d", entry.IP.String(), entry.CIDR, entry.IP2.String(), entry.CIDR2) +} + +func IsGatewayRole(network *types.Network, nodeName types.NodeName) bool { + return network != nil && + network.LocalEndpoint != nil && + network.LocalEndpoint.NodeName == nodeName +} + +func IsCentreGatewayRole(centralGw *types.Endpoint, localNodeName types.NodeName) bool { + return centralGw != nil && centralGw.NodeName == localNodeName +} + +func CalIPSetOnNode(network *types.Network, centralGw *types.Endpoint, nodeName types.NodeName, ipset ipsetutil.IPSetInterface) map[string]*netlink.IPSetEntry { + set := make(map[string]*netlink.IPSetEntry) + subnets := make([]string, 0) + for _, v := range network.RemoteNodeInfo { + nodeInfo := network.RemoteNodeInfo[types.NodeName(v.NodeName)] + if nodeInfo == nil { + klog.Errorf("node %s not found in RemoteNodeInfo", v.NodeName) + continue + } + subnets = append(subnets, nodeInfo.Subnets...) + } + var err error + subnets, err = cidrman.MergeCIDRs(subnets) + if err != nil { + return set + } + if IsCentreGatewayRole(centralGw, nodeName) { + subnets = append(subnets, network.LocalEndpoint.Subnets...) + for _, srcCIDR := range subnets { + _, ipNet, err := net.ParseCIDR(srcCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", srcCIDR, err.Error()) + continue + } + ones, _ := ipNet.Mask.Size() + entry := &netlink.IPSetEntry{ + IP: ipNet.IP, + CIDR: uint8(ones), + IP2: ipNet.IP, + CIDR2: uint8(ones), + Replace: true, + } + set[ipset.Key(entry)] = entry + } + } else { + for _, localCIDR := range network.LocalEndpoint.Subnets { + _, localIPNet, err := net.ParseCIDR(localCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", localCIDR, err.Error()) + continue + } + localOnes, _ := localIPNet.Mask.Size() + for _, remoteCIDR := range subnets { + _, remoteIPNet, err := net.ParseCIDR(remoteCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", remoteCIDR, err.Error()) + continue + } + remoteOnes, _ := remoteIPNet.Mask.Size() + entry := &netlink.IPSetEntry{ + IP: localIPNet.IP, + CIDR: uint8(localOnes), + IP2: remoteIPNet.IP, + CIDR2: uint8(remoteOnes), + Replace: true, + } + set[ipset.Key(entry)] = entry + } + } + } + return set +} + +func CleanupRavenSkipNATIPSet() error { + errList := errorlist.List{} + ipset, err := ipsetutil.New(RavenSkipNatSet, RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{}) + if err != nil { + errList = errList.Append(fmt.Errorf("error ensure ip set %s: %s", RavenSkipNatSet, err)) + } + err = ipset.Flush() + if err != nil { + errList = errList.Append(fmt.Errorf("error flushing ipset: %s", err)) + } + err = ipset.Destroy() + if err != nil { + errList = errList.Append(fmt.Errorf("error destroying ipset: %s", err)) + } + return errList.AsError() +} diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan.go b/pkg/networkengine/vpndriver/libreswan/libreswan.go index 50395b0..e94db6c 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan.go @@ -17,10 +17,14 @@ package libreswan import ( + "bufio" + "bytes" "fmt" "os" "os/exec" + "regexp" "strconv" + "strings" "syscall" "time" @@ -28,9 +32,12 @@ import ( "k8s.io/klog/v2" "github.com/openyurtio/raven/cmd/agent/app/config" + networkutil "github.com/openyurtio/raven/pkg/networkengine/util" + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" + vpndriveripset "github.com/openyurtio/raven/pkg/networkengine/vpndriver/ipset" "github.com/openyurtio/raven/pkg/types" "github.com/openyurtio/raven/pkg/utils" ) @@ -40,12 +47,16 @@ const ( // DriverName specifies name of libreswan VPN backend driver. DriverName = "libreswan" + + IKESAESTABLISHED = "STATE_V2_ESTABLISHED_IKE_SA" + ChILDSAESTABLISHED = "STATE_V2_ESTABLISHED_CHILD_SA" ) var _ vpndriver.Driver = (*libreswan)(nil) // can be modified for testing. var whackCmd = whackCmdFn +var ipsecCmd = ipsecCmdFn var findCentralGw = vpndriver.FindCentralGwFn var enableCreateEdgeConnection = vpndriver.EnableCreateEdgeConnection @@ -58,11 +69,11 @@ const ( ) type libreswan struct { - relayConnections map[string]*vpndriver.Connection - edgeConnections map[string]*vpndriver.Connection + connections map[string]bool nodeName types.NodeName centralGw *types.Endpoint iptables iptablesutil.IPTablesInterface + ipset ipsetutil.IPSetInterface listenPort string keepaliveInterval int keepaliveTimeout int @@ -73,6 +84,11 @@ func (l *libreswan) Init() (err error) { if err != nil { return err } + l.ipset, err = ipsetutil.New(vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{}) + if err != nil { + return err + } + // Ensure secrets file _, err = os.Stat(SecretFile) if err == nil { @@ -95,8 +111,7 @@ func (l *libreswan) Init() (err error) { func New(cfg *config.Config) (vpndriver.Driver, error) { return &libreswan{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), + connections: make(map[string]bool), nodeName: types.NodeName(cfg.NodeName), listenPort: cfg.Tunnel.VPNPort, keepaliveInterval: cfg.Tunnel.KeepAliveInterval, @@ -110,12 +125,13 @@ func (l *libreswan) Apply(network *types.Network, routeDriverMTUFn func(*types.N return l.Cleanup() } if network.LocalEndpoint.NodeName != l.nodeName { - klog.Infof(utils.FormatTunnel("the current node is not gateway node, cleaning vpn connections")) + klog.Infof("the current node is not gateway node, cleaning vpn connections") return l.Cleanup() } - if err := l.createConnections(network); err != nil { - return fmt.Errorf("error create VPN tunnels: %v", err) + l.centralGw = findCentralGw(network) + if err := l.ensureConnections(network); err != nil { + return fmt.Errorf("error ensure VPN tunnels: %s", err.Error()) } return nil @@ -176,136 +192,181 @@ func (l *libreswan) getEndpointResolver(network *types.Network) func(centralGw, } } -func (l *libreswan) createConnections(network *types.Network) error { - l.centralGw = findCentralGw(network) +func (l *libreswan) ensureConnections(network *types.Network) error { + defer func() { + // wait connection is established + time.Sleep(5 * time.Second) + }() + + l.connections = currentConnections() + if err := l.deleteUnavailableConn(); err != nil { + return fmt.Errorf("delete unavailabel connections error %s", err.Error()) + } desiredEdgeConns, desiredRelayConns := l.computeDesiredConnections(network) - if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { - klog.Infof(utils.FormatTunnel("no desired connections, cleaning vpn connections")) - return l.Cleanup() + klog.Infof("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns) + + if err := l.deleteUndesiredConn(desiredEdgeConns, desiredRelayConns); err != nil { + return fmt.Errorf("ensure delete undesired connections error %s", err.Error()) } - klog.Infof(utils.FormatTunnel("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns)) + if err := l.ensureEdgeConnections(desiredEdgeConns); err != nil { + return fmt.Errorf("ensure delete edge-edge connections error %s", err.Error()) + } - if err := l.createEdgeConnections(desiredEdgeConns); err != nil { - return err + if err := l.ensureRelayConnections(desiredRelayConns); err != nil { + return fmt.Errorf("ensure delete cloud-edge connections error %s", err.Error()) } - if err := l.createRelayConnections(desiredRelayConns); err != nil { - return err + + if err := l.ensureRavenSkipNAT(network); err != nil { + return fmt.Errorf("ensure raven skip nat error %s", err.Error()) } return nil } -func (l *libreswan) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { - if len(desiredEdgeConns) == 0 { - klog.Infof("no desired edge connections") - return nil +func currentConnections() map[string]bool { + connections := make(map[string]bool) + reg := regexp.MustCompile(`"([^"]+)"`) + out, err := ipsecCmd("auto", "--status") + if err != nil { + return connections + } + foundConnectionList := false + scanner := bufio.NewScanner(out) + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "Connection list") { + foundConnectionList = true + continue + } + if foundConnectionList { + matches := reg.FindAllStringSubmatch(line, -1) + for _, match := range matches { + if len(match) > 1 { + connections[match[1]] = false + } + } + } } + for k := range connections { + out, err = ipsecCmd("whack", "--showstates") + if err != nil { + continue + } + foundIKESAEstablished := false + foundChildSAEstablished := false + scanner = bufio.NewScanner(out) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, k) { + continue + } + if strings.Contains(line, IKESAESTABLISHED) { + foundIKESAEstablished = true + } + if strings.Contains(line, ChILDSAESTABLISHED) { + foundChildSAEstablished = true + } + } + if foundIKESAEstablished && foundChildSAEstablished { + connections[k] = true + } + } + return connections +} +func (l *libreswan) deleteUnavailableConn() error { errList := errorlist.List{} - - // remove unwanted connections - for connName := range l.edgeConnections { - if _, ok := desiredEdgeConns[connName]; !ok { + for connName, established := range l.connections { + if !established { err := l.whackDelConnection(connName) if err != nil { errList = errList.Append(err) - klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + klog.ErrorS(err, "error delete unavailable connection", "connectionName", connName) continue } - delete(l.edgeConnections, connName) + delete(l.connections, connName) } } - - // add new connections - for name, connection := range desiredEdgeConns { - err := l.connectToEdgeEndpoint(name, connection) - errList = errList.Append(err) - } - return errList.AsError() } -func (l *libreswan) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection) error { - if len(desiredRelayConns) == 0 { - klog.Infof("no desired relay connections") - return nil - } - +func (l *libreswan) deleteUndesiredConn(desiredEdgeConns, desiredRelayConns map[string]*vpndriver.Connection) error { errList := errorlist.List{} - - // remove unwanted connections - for connName := range l.relayConnections { - if _, ok := desiredRelayConns[connName]; !ok { + desireConn := make(map[string]struct{}) + for k := range desiredEdgeConns { + desireConn[k] = struct{}{} + } + for k := range desiredRelayConns { + desireConn[k] = struct{}{} + } + for connName := range l.connections { + if _, ok := desireConn[connName]; !ok { err := l.whackDelConnection(connName) if err != nil { errList = errList.Append(err) - klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + klog.ErrorS(err, "error delete undesired connection", "connectionName", connName) continue } - if l.centralGw.NodeName == l.nodeName { - if conn, ok := l.relayConnections[connName]; ok && conn != nil { - err := l.deleteRavenSkipNAT(conn) - if err != nil { - errList = errList.Append(err) - } - } - } - delete(l.relayConnections, connName) + delete(l.connections, connName) } } + return errList.AsError() +} + +func (l *libreswan) ensureEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} + for name, connection := range desiredEdgeConns { + err := l.connectToEdgeEndpoint(name, connection) + errList = errList.Append(err) + } + return errList.AsError() +} - // add new connections +func (l *libreswan) ensureRelayConnections(desiredRelayConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} for name, connection := range desiredRelayConns { err := l.connectToEndpoint(name, connection) errList = errList.Append(err) - if l.centralGw.NodeName == l.nodeName { - err = l.ensureRavenSkipNAT(connection) - if err != nil { - errList = errList.Append(err) - } - } } - return errList.AsError() } -func (l *libreswan) ensureRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - for _, subnet := range l.centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList - } +func (l *libreswan) ensureRavenSkipNAT(network *types.Network) error { + if !vpndriveripset.IsGatewayRole(network, l.nodeName) { + klog.Infof("node %s is not gateway, skip add skip nat", l.nodeName) + return nil } - // for raven skip nat - if err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err)) + + // The desired and current ipset entries calculated from given network. + // The key is ip set entry + var err error + l.ipset, err = ipsetutil.New(vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{KeyFunc: vpndriveripset.KeyFunc}) + if err != nil { + return fmt.Errorf("error ensure ipset %s, type %s", vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType) } - if err := l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err)) + currentSet, err := networkutil.ListIPSetOnNode(l.ipset) + if err != nil { + return fmt.Errorf("error listing ip set %s on node: %s", l.ipset.Name(), err.Error()) } - if err := l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err)) + desiredSet := vpndriveripset.CalIPSetOnNode(network, l.centralGw, l.nodeName, l.ipset) + err = networkutil.ApplyIPSet(l.ipset, currentSet, desiredSet) + if err != nil { + return fmt.Errorf("error applying ip set: %s", err) } - return errList -} -func (l *libreswan) deleteRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + // for raven skip nat + if err = l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) } - for _, subnet := range l.centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList - } + if err = l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) } - err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") - if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err)) + if err = l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-m", "set", "--match-set", vpndriveripset.RavenSkipNatSet, "src,dst", "-j", "ACCEPT"); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) } - return errList + + return nil } func (l *libreswan) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection) { @@ -456,6 +517,24 @@ func whackCmdFn(args ...string) error { return nil } +func ipsecCmdFn(args ...string) (*bytes.Buffer, error) { + var err error + var output bytes.Buffer + for i := 0; i < 5; i++ { + cmd := exec.Command("ipsec", args...) + cmd.Stdout = &output + err = cmd.Run() + if err == nil { + break + } + time.Sleep(1 * time.Second) + } + if err != nil { + return nil, fmt.Errorf("error ipsec with %v, error %s", args, err.Error()) + } + return &output, nil +} + func (l *libreswan) whackDelConnection(conn string) error { return whackCmd("--delete", "--name", conn) } @@ -466,43 +545,36 @@ func connectionName(localID, remoteID, leftSubnet, rightSubnet string) string { func (l *libreswan) Cleanup() error { errList := errorlist.List{} - for name := range l.relayConnections { + connections := currentConnections() + for name := range connections { if err := l.whackDelConnection(name); err != nil { errList = errList.Append(err) klog.ErrorS(err, "fail to delete connection", "connectionName", name) } - if l.centralGw != nil && l.centralGw.NodeName == l.nodeName { - if conn, ok := l.relayConnections[name]; ok && conn != nil { - err := l.deleteRavenSkipNAT(conn) - if err != nil { - errList = errList.Append(err) - } - } - } } - for name := range l.edgeConnections { - if err := l.whackDelConnection(name); err != nil { - errList = errList.Append(err) - klog.ErrorS(err, "fail to delete connection", "connectionName", name) - } - } - l.relayConnections = make(map[string]*vpndriver.Connection) - l.edgeConnections = make(map[string]*vpndriver.Connection) err := netlinkutil.XfrmPolicyFlush() errList = errList.Append(err) + err = netlinkutil.XfrmStateFlush() + errList = errList.Append(err) + + err = vpndriveripset.CleanupRavenSkipNATIPSet() + if err != nil { + errList = errList.Append(fmt.Errorf("error cleanup ipset %s, %s", vpndriveripset.RavenSkipNatSet, err.Error())) + } err = l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err.Error())) } err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err.Error())) } err = l.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err.Error())) } + return errList.AsError() } @@ -542,8 +614,7 @@ func (l *libreswan) runPluto() error { func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := l.relayConnections[name]; ok { - klog.InfoS("skipping connect because connection already exists", "connectionName", name) + if _, ok := l.connections[name]; ok { return errList } err := l.whackConnectToEndpoint(name, connection) @@ -552,14 +623,12 @@ func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connect klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.relayConnections[name] = connection return errList } func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := l.edgeConnections[name]; ok { - klog.InfoS("skipping connect because connection already exists", "connectionName", name) + if _, ok := l.connections[name]; ok { return errList } err := l.whackConnectToEdgeEndpoint(name, connection) @@ -568,6 +637,5 @@ func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Con klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.edgeConnections[name] = connection return errList } diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go index 529963d..f35e4cd 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go @@ -25,7 +25,6 @@ import ( iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" - "github.com/openyurtio/raven/pkg/networkengine/vpndriver" "github.com/openyurtio/raven/pkg/types" ) @@ -130,7 +129,7 @@ func TestLibreswan_Apply(t *testing.T) { nodeName: "localGwNode", // It is unable to set up any vpn connections in such case and should clean up vpn connections expectedConnName: map[string]struct{}{}, - shouldCleanup: true, + shouldCleanup: false, network: &types.Network{ LocalEndpoint: &types.Endpoint{ GatewayName: "localGw", @@ -372,9 +371,8 @@ func TestLibreswan_Apply(t *testing.T) { whackCmd = w.whackCmd a := assert.New(t) l := &libreswan{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(v.nodeName), + connections: make(map[string]bool), + nodeName: types.NodeName(v.nodeName), } var err error l.iptables, err = iptablesutil.New() diff --git a/pkg/networkengine/vpndriver/wireguard/wireguard.go b/pkg/networkengine/vpndriver/wireguard/wireguard.go index e17b24c..eb01238 100644 --- a/pkg/networkengine/vpndriver/wireguard/wireguard.go +++ b/pkg/networkengine/vpndriver/wireguard/wireguard.go @@ -26,7 +26,6 @@ import ( "strconv" "time" - "github.com/openyurtio/api/raven/v1beta1" "github.com/pkg/errors" "github.com/vdobler/ht/errorlist" "github.com/vishvananda/netlink" @@ -36,10 +35,13 @@ import ( "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/openyurtio/api/raven/v1beta1" "github.com/openyurtio/raven/cmd/agent/app/config" networkutil "github.com/openyurtio/raven/pkg/networkengine/util" + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" + vpnipset "github.com/openyurtio/raven/pkg/networkengine/vpndriver/ipset" "github.com/openyurtio/raven/pkg/types" "github.com/openyurtio/raven/pkg/utils" ) @@ -61,6 +63,9 @@ const ( DeviceName = "raven-wg0" // DefaultListenPort specifies port of WireGuard listened. DefaultListenPort = 4500 + + ravenSkipNatSet = "raven-skip-nat-set" + ravenSkipNatSetType = "hash:net,net" ) var findCentralGw = vpndriver.FindCentralGwFn @@ -78,10 +83,10 @@ type wireguard struct { psk wgtypes.Key wgLink netlink.Link - relayConnections map[string]*vpndriver.Connection - edgeConnections map[string]*vpndriver.Connection iptables iptablesutil.IPTablesInterface + ipset ipsetutil.IPSetInterface nodeName types.NodeName + centralGw *types.Endpoint ravenClient client.Client listenPort int keepaliveInterval int @@ -93,8 +98,6 @@ func New(cfg *config.Config) (vpndriver.Driver, error) { port = DefaultListenPort } return &wireguard{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), nodeName: types.NodeName(cfg.NodeName), ravenClient: cfg.Manager.GetClient(), listenPort: port, @@ -217,56 +220,78 @@ func (w *wireguard) ensureWgLink(network *types.Network, routeDriverMTUFn func(* return nil } -func (w *wireguard) createConnections(network *types.Network) error { +func (w *wireguard) ensureConnections(network *types.Network) error { desiredEdgeConns, desiredRelayConns, centralAllowedIPs := w.computeDesiredConnections(network) if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { klog.Infof("no desired connections, cleaning vpn connections") return w.Cleanup() } - klog.Infof("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns) - centralGw := findCentralGw(network) - if err := w.createEdgeConnections(desiredEdgeConns); err != nil { - return err + var err error + + peers := w.currentPeers() + klog.Infof("current peers: %v", peers) + + if err = w.deleteUndesiredPeers(peers, desiredEdgeConns, desiredRelayConns); err != nil { + return fmt.Errorf("ensure edge-edge peers error %s", err.Error()) } - if err := w.createRelayConnections(desiredRelayConns, centralAllowedIPs, centralGw); err != nil { - return err + + if err = w.ensureEdgePeers(desiredEdgeConns); err != nil { + return fmt.Errorf("ensure edge-edge peers error %s", err.Error()) + } + if err = w.ensureRelayPeers(desiredRelayConns, centralAllowedIPs); err != nil { + return fmt.Errorf("ensure cloud-edge peers error %s", err.Error()) + } + + if err = w.ensureRavenSkipNAT(network); err != nil { + return fmt.Errorf("ensure raven skip nat error %s", err.Error()) } return nil } -func (w *wireguard) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { - if len(desiredEdgeConns) == 0 { - klog.Infof("no desired edge connections") - return nil +func (w *wireguard) currentPeers() map[string]wgtypes.Peer { + set := make(map[string]wgtypes.Peer) + dev, err := w.wgClient.Device(DeviceName) + if err != nil { + klog.Errorf("can not found wireguard device %s, error %s", DeviceName, err.Error()) + return set } + for _, peer := range dev.Peers { + set[peer.PublicKey.String()] = peer + } + return set +} - for connName, connection := range w.edgeConnections { - if _, ok := desiredEdgeConns[connName]; !ok { - remoteKey := keyFromEndpoint(connection.RemoteEndpoint) - if err := w.removePeer(remoteKey); err == nil { - delete(w.edgeConnections, connName) - } +func (w *wireguard) deleteUndesiredPeers(currentConns map[string]wgtypes.Peer, desiredEdgeConns, desiredRelayConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} + desiredPeers := make(map[string]struct{}) + for _, connection := range desiredEdgeConns { + desiredPeers[keyFromEndpoint(connection.RemoteEndpoint).String()] = struct{}{} + } + for _, connection := range desiredRelayConns { + desiredPeers[keyFromEndpoint(connection.RemoteEndpoint).String()] = struct{}{} + } + var err error + for key, peer := range currentConns { + if _, ok := desiredPeers[key]; !ok { + err = w.removePeer(&peer.PublicKey) + errList = errList.Append(err) } } + return errList.AsError() +} +func (w *wireguard) ensureEdgePeers(desiredEdgeConns map[string]*vpndriver.Connection) error { + if len(desiredEdgeConns) == 0 { + klog.Infof("no desired edge connections") + return nil + } peerConfigs := make([]wgtypes.PeerConfig, 0) - for name, newConn := range desiredEdgeConns { - newKey := keyFromEndpoint(newConn.RemoteEndpoint) - - if oldConn, ok := w.edgeConnections[name]; ok { - oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) - if oldKey.String() != newKey.String() { - if err := w.removePeer(oldKey); err == nil { - delete(w.edgeConnections, name) - } - } - } - + for _, newConn := range desiredEdgeConns { klog.InfoS("create edge-to-edge connection", "c", newConn) - + newKey := keyFromEndpoint(newConn.RemoteEndpoint) allowedIPs := parseSubnets(newConn.RemoteEndpoint.Subnets) ka := time.Duration(w.keepaliveInterval) var remotePort int @@ -284,59 +309,29 @@ func (w *wireguard) createEdgeConnections(desiredEdgeConns map[string]*vpndriver IP: net.ParseIP(newConn.RemoteEndpoint.PublicIP), Port: remotePort, }, - PersistentKeepaliveInterval: &ka, ReplaceAllowedIPs: true, AllowedIPs: allowedIPs, }) } - - if err := w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ + return w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ ReplacePeers: true, Peers: peerConfigs, - }); err != nil { - return fmt.Errorf("error add peers: %v", err) - } - - w.edgeConnections = desiredEdgeConns - - return nil + }) } -func (w *wireguard) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection, centralAllowedIPs []string, centralGw *types.Endpoint) error { +func (w *wireguard) ensureRelayPeers(desiredRelayConns map[string]*vpndriver.Connection, centralAllowedIPs []string) error { if len(desiredRelayConns) == 0 { klog.Infof("no desired relay connections") return nil } - - // delete unwanted connections - for connName, connection := range w.relayConnections { - if _, ok := desiredRelayConns[connName]; !ok { - remoteKey := keyFromEndpoint(connection.RemoteEndpoint) - if err := w.removePeer(remoteKey); err == nil { - delete(w.relayConnections, connName) - } - } - } - // add or update connections peerConfigs := make([]wgtypes.PeerConfig, 0) - for name, newConn := range desiredRelayConns { - newKey := keyFromEndpoint(newConn.RemoteEndpoint) - - if oldConn, ok := w.relayConnections[name]; ok { - oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) - if oldKey.String() != newKey.String() { - if err := w.removePeer(oldKey); err == nil { - delete(w.relayConnections, name) - } - } - } - + for _, newConn := range desiredRelayConns { klog.InfoS("create connection", "c", newConn) - + newKey := keyFromEndpoint(newConn.RemoteEndpoint) allowedIPs := parseSubnets(newConn.RemoteEndpoint.Subnets) - if newConn.RemoteEndpoint.NodeName == centralGw.NodeName { + if w.centralGw != nil && newConn.RemoteEndpoint.NodeName == w.centralGw.NodeName { allowedIPs = append(allowedIPs, parseSubnets(centralAllowedIPs)...) } @@ -357,16 +352,10 @@ func (w *wireguard) createRelayConnections(desiredRelayConns map[string]*vpndriv }) } - if err := w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ + return w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ ReplacePeers: false, Peers: peerConfigs, - }); err != nil { - return fmt.Errorf("error add peers: %v", err) - } - - w.relayConnections = desiredRelayConns - - return nil + }) } func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.Network) (int, error)) error { @@ -378,7 +367,7 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N klog.Infof("the current node is not gateway node, cleaning vpn connections") return w.Cleanup() } - + w.centralGw = findCentralGw(network) if _, ok := network.LocalEndpoint.Config[PublicKey]; !ok || network.LocalEndpoint.Config[PublicKey] != w.privateKey.PublicKey().String() { err := w.configGatewayPublicKey(string(network.LocalEndpoint.GatewayName), string(network.LocalEndpoint.NodeName)) if err != nil { @@ -387,24 +376,17 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N return errors.New("retry to config public key") } - centralGw := findCentralGw(network) - if centralGw.NodeName == w.nodeName { - if err := w.ensureRavenSkipNAT(); err != nil { - return fmt.Errorf("error ensure raven skip nat: %s", err) - } - } - if err := w.ensureWgLink(network, routeDriverMTUFn); err != nil { - return fmt.Errorf("fail to ensure wireguar link: %v", err) + return fmt.Errorf("fail to ensure wireguar link: %s", err.Error()) } // 3. Config device route and rules currentRoutes, err := networkutil.ListRoutesOnNode(wgRouteTableID) if err != nil { - return fmt.Errorf("error listing wireguard routes on node: %s", err) + return fmt.Errorf("error listing wireguard routes on node: %s", err.Error()) } currentRules, err := networkutil.ListRulesOnNode(wgRouteTableID) if err != nil { - return fmt.Errorf("error listing wireguard rules on node: %s", err) + return fmt.Errorf("error listing wireguard rules on node: %s", err.Error()) } desiredRoutes := w.calWgRoutes(network) @@ -412,15 +394,52 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N err = networkutil.ApplyRoutes(currentRoutes, desiredRoutes) if err != nil { - return fmt.Errorf("error applying wireguard routes: %s", err) + return fmt.Errorf("error applying wireguard routes: %s", err.Error()) } err = networkutil.ApplyRules(currentRules, desiredRules) if err != nil { - return fmt.Errorf("error applying wireguard rules: %s", err) + return fmt.Errorf("error applying wireguard rules: %s", err.Error()) + } + + if err = w.ensureConnections(network); err != nil { + return fmt.Errorf("error ensure VPN tunnels: %s", err.Error()) } - if err := w.createConnections(network); err != nil { - return fmt.Errorf("error create VPN tunnels: %v", err) + return nil +} + +func (w *wireguard) ensureRavenSkipNAT(network *types.Network) error { + if !vpnipset.IsGatewayRole(network, w.nodeName) { + klog.Infof("node %s is not gateway, skip add skip nat", w.nodeName) + return nil + } + + // The desired and current ipset entries calculated from given network. + // The key is ip set entry + var err error + w.ipset, err = ipsetutil.New(ravenSkipNatSet, ravenSkipNatSetType, ipsetutil.IpsetWrapperOption{KeyFunc: vpnipset.KeyFunc}) + if err != nil { + return fmt.Errorf("error new ipset %s, type %s", vpnipset.RavenSkipNatSet, vpnipset.RavenSkipNatSetType) + } + currentSet, err := networkutil.ListIPSetOnNode(w.ipset) + if err != nil { + return fmt.Errorf("error listing ip set %s on node: %s", w.ipset.Name(), err.Error()) + } + desiredSet := vpnipset.CalIPSetOnNode(network, w.centralGw, w.nodeName, w.ipset) + err = networkutil.ApplyIPSet(w.ipset, currentSet, desiredSet) + if err != nil { + return fmt.Errorf("error applying ip set: %s", err) + } + + // for raven skip nat + if err = w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) + } + if err = w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", DeviceName, "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) + } + if err = w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-m", "set", "--match-set", vpnipset.RavenSkipNatSet, "src,dst", "-j", "ACCEPT"); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) } return nil @@ -457,12 +476,24 @@ func (w *wireguard) Cleanup() error { errList = errList.Append(fmt.Errorf("error delete existing wireguard device %q: %v", DeviceName, err)) } - if err = w.deleteRavenSkipNAT(); err != nil { - errList = errList.Append(err) + err = vpnipset.CleanupRavenSkipNATIPSet() + if err != nil { + errList = errList.Append(fmt.Errorf("error cleanup ipset %s, %s", vpnipset.RavenSkipNatSet, err.Error())) + } + + err = w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + } + err = w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", DeviceName, "-j", iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err)) + } + err = w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err)) } - w.relayConnections = make(map[string]*vpndriver.Connection) - w.edgeConnections = make(map[string]*vpndriver.Connection) return errList.AsError() } @@ -604,32 +635,3 @@ func parseSubnets(subnets []string) []net.IPNet { } return nets } - -func (w *wireguard) ensureRavenSkipNAT() error { - // for raven skip nat - if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) - } - if err := w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-j", "ACCEPT"); err != nil { - return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) - } - - return nil -} - -func (w *wireguard) deleteRavenSkipNAT() error { - if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err) - } - - return nil -}