From c03509fd81a30b44373df5a64b5c6b0a3d1592e0 Mon Sep 17 00:00:00 2001 From: ignatella Date: Fri, 13 Sep 2024 17:59:52 +0200 Subject: [PATCH 1/2] Add: netlink socket options --- conn.go | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index a9fbf2b..cdbc44c 100644 --- a/conn.go +++ b/conn.go @@ -37,16 +37,20 @@ type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial NetNS int // fd referencing the network namespace netlink will interact with. - lasting bool // establish a lasting connection to be used across multiple netlink operations. - mu sync.Mutex // protects the following state - messages []netlink.Message - err error - nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + sockOptions []SockOption } // ConnOption is an option to change the behavior of the nftables Conn returned by Open. type ConnOption func(*Conn) +// SockOption is an option to change the behavior of the netlink socket used by the nftables Conn. +type SockOption func(*netlink.Conn) error + // New returns a netlink connection for querying and modifying nftables. Some // aspects of the new netlink connection can be configured using the options // WithNetNSFd, WithTestDial, and AsLasting. @@ -101,6 +105,14 @@ func WithTestDial(f nltest.Func) ConnOption { } } +// WithSockOptions sets the specified socket options when creating a new netlink +// connection. +func WithSockOptions(opts ...SockOption) ConnOption { + return func(cc *Conn) { + cc.sockOptions = append(cc.sockOptions, opts...) + } +} + // netlinkCloser is returned by netlinkConn(UnderLock) and must be called after // being done with the returned netlink connection in order to properly close // this connection, if necessary. @@ -284,11 +296,30 @@ func (cc *Conn) FlushRuleset() { } func (cc *Conn) dialNetlink() (*netlink.Conn, error) { + var ( + conn *netlink.Conn + err error = nil + ) + if cc.TestDial != nil { - return nltest.Dial(cc.TestDial), nil + conn = nltest.Dial(cc.TestDial) + } else { + conn, err = netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) + } + + if err != nil { + return nil, err + } + + for _, opt := range cc.sockOptions { + err := opt(conn) + + if err != nil { + return nil, err + } } - return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) + return conn, nil } func (cc *Conn) setErr(err error) { From ec6e32d656afa11da9f1de0a8fb8aab3ac965541 Mon Sep 17 00:00:00 2001 From: ignatella Date: Mon, 23 Sep 2024 09:59:28 +0200 Subject: [PATCH 2/2] Update: code refactoring --- conn.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index cdbc44c..25d88e0 100644 --- a/conn.go +++ b/conn.go @@ -298,7 +298,7 @@ func (cc *Conn) FlushRuleset() { func (cc *Conn) dialNetlink() (*netlink.Conn, error) { var ( conn *netlink.Conn - err error = nil + err error ) if cc.TestDial != nil { @@ -312,9 +312,7 @@ func (cc *Conn) dialNetlink() (*netlink.Conn, error) { } for _, opt := range cc.sockOptions { - err := opt(conn) - - if err != nil { + if err := opt(conn); err != nil { return nil, err } }