Skip to content

Commit

Permalink
Don't use global logrus logger
Browse files Browse the repository at this point in the history
  • Loading branch information
jschwinger233 committed Jan 23, 2024
1 parent a8e2707 commit 4019fb4
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 25 deletions.
5 changes: 5 additions & 0 deletions control/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,11 @@ func NewControlPlane(
}
go dnsUpstream.InitUpstreams()

InitDaeNetns(log)
if err = InitSysctlManager(log); err != nil {
return nil, err
}

close(plane.ready)
return plane, nil
}
Expand Down
8 changes: 4 additions & 4 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (er
// resp is valid.
cache2 := c.LookupDnsRespCache(c.cacheKey(qname, qtype2), true)
if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() {
return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn)
return sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn)
} else {
return c.sendReject_(dnsMessage, req)
}
Expand Down Expand Up @@ -453,7 +453,7 @@ func (c *DnsController) handle_(
if resp := c.LookupDnsRespCache_(dnsMessage, cacheKey, false); resp != nil {
// Send cache to client directly.
if needResp {
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, resp, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
}
Expand Down Expand Up @@ -501,7 +501,7 @@ func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest)
if err != nil {
return fmt.Errorf("pack DNS packet: %w", err)
}
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -751,7 +751,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
if err != nil {
return err
}
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
if err = sendPkt(c.log, data, req.realDst, req.realSrc, req.src, req.lConn); err != nil {
return err
}
}
Expand Down
18 changes: 11 additions & 7 deletions control/netns_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ var (
)

type DaeNetns struct {
log *logrus.Logger

setupDone atomic.Bool
mu sync.Mutex

dae0, dae0peer netlink.Link
hostNs, daeNs netns.NsHandle
}

func init() {
daeNetns = &DaeNetns{}
func InitDaeNetns(log *logrus.Logger) {
daeNetns = &DaeNetns{
log: log,
}
}

func GetDaeNetns() *DaeNetns {
Expand Down Expand Up @@ -85,7 +89,7 @@ func (ns *DaeNetns) With(f func() error) (err error) {
}

func (ns *DaeNetns) setup() (err error) {
logrus.Trace("setting up dae netns")
ns.log.Trace("setting up dae netns")

runtime.LockOSThread()
defer runtime.UnlockOSThread()
Expand Down Expand Up @@ -286,17 +290,17 @@ func (ns *DaeNetns) monitorDae0LinkAddr() {

err := netlink.LinkSubscribe(ch, done)
if err != nil {
logrus.Errorf("failed to subscribe link updates: %v", err)
ns.log.Errorf("failed to subscribe link updates: %v", err)
}
if ns.dae0, err = netlink.LinkByName(HostVethName); err != nil {
logrus.Errorf("failed to get link dae0: %v", err)
ns.log.Errorf("failed to get link dae0: %v", err)
}
if err = ns.updateNeigh(); err != nil {
logrus.Errorf("failed to update neigh: %v", err)
ns.log.Errorf("failed to update neigh: %v", err)
}
for msg := range ch {
if msg.Link.Attrs().Name == HostVethName && !bytes.Equal(msg.Link.Attrs().HardwareAddr, ns.dae0.Attrs().HardwareAddr) {
logrus.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed")
ns.log.WithField("old addr", ns.dae0.Attrs().HardwareAddr).WithField("new addr", msg.Link.Attrs().HardwareAddr).Info("dae0 link addr changed")
ns.dae0 = msg.Link
ns.updateNeigh()
}
Expand Down
22 changes: 11 additions & 11 deletions control/sysctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@ const SysctlPrefixPath = "/proc/sys/"
var sysctl *SysctlManager

type SysctlManager struct {
log *logrus.Logger
mux sync.Mutex
watcher *fsnotify.Watcher
expectations map[string]string
}

func init() {
var err error
if sysctl, err = NewSysctlManager(); err != nil {
logrus.Fatalf("failed to create sysctl manager: %v", err)
}
func InitSysctlManager(log *logrus.Logger) (err error) {
sysctl, err = NewSysctlManager(log)
return err
}

func NewSysctlManager() (*SysctlManager, error) {
func NewSysctlManager(log *logrus.Logger) (*SysctlManager, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}

manager := &SysctlManager{
log: log,
mux: sync.Mutex{},
watcher: watcher,
expectations: map[string]string{},
Expand All @@ -49,20 +49,20 @@ func (s *SysctlManager) startWatch() {
return
}
if event.Has(fsnotify.Write) {
logrus.Tracef("sysctl write event: %+v", event)
s.log.Tracef("sysctl write event: %+v", event)
s.mux.Lock()
expected, ok := s.expectations[event.Name]
s.mux.Unlock()
if ok {
raw, err := os.ReadFile(event.Name)
if err != nil {
logrus.Errorf("failed to read sysctl file %s: %v", event.Name, err)
s.log.Errorf("failed to read sysctl file %s: %v", event.Name, err)
}
value := strings.TrimSpace(string(raw))
if value != expected {
logrus.Infof("sysctl %s has unexpected value %s, expected %s", event.Name, value, expected)
s.log.Infof("sysctl %s has unexpected value %s, expected %s", event.Name, value, expected)
if err := os.WriteFile(event.Name, []byte(expected), 0644); err != nil {
logrus.Errorf("failed to write sysctl file %s: %v", event.Name, err)
s.log.Errorf("failed to write sysctl file %s: %v", event.Name, err)
}
}
}
Expand All @@ -71,7 +71,7 @@ func (s *SysctlManager) startWatch() {
if !ok {
return
}
logrus.Errorf("sysctl watcher error: %v", err)
s.log.Errorf("sysctl watcher error: %v", err)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions control/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout
}

// sendPkt uses bind first, and fallback to send hdr if addr is in use.
func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) {
func sendPkt(log *logrus.Logger, data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn) (err error) {

transparentTimeout := AnyfromTimeout
if from.Port() == 53 {
Expand All @@ -58,7 +58,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
}
uConn, _, err := DefaultAnyfromPool.GetOrCreate(from.String(), transparentTimeout)
if err != nil && errors.Is(err, syscall.EADDRINUSE) {
logrus.WithField("from", from).
log.WithField("from", from).
WithField("to", to).
WithField("realTo", realTo).
Trace("Port in use, fallback to use netns.")
Expand Down Expand Up @@ -187,7 +187,7 @@ getNew:
// Handler handles response packets and send it to the client.
Handler: func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
return sendPkt(data, from, realSrc, src, lConn)
return sendPkt(c.log, data, from, realSrc, src, lConn)
},
NatTimeout: natTimeout,
GetDialOption: func() (option *DialOption, err error) {
Expand Down

0 comments on commit 4019fb4

Please sign in to comment.