diff --git a/component/iface/iface.go b/component/iface/iface.go index 1d0219dfac..d543725a3d 100644 --- a/component/iface/iface.go +++ b/component/iface/iface.go @@ -23,7 +23,7 @@ var ( var interfaces = singledo.NewSingle[map[string]*Interface](time.Second * 20) -func ResolveInterface(name string) (*Interface, error) { +func Interfaces() (map[string]*Interface, error) { value, err, _ := interfaces.Do(func() (map[string]*Interface, error) { ifaces, err := net.Interfaces() if err != nil { @@ -69,11 +69,15 @@ func ResolveInterface(name string) (*Interface, error) { return r, nil }) + return value, err +} + +func ResolveInterface(name string) (*Interface, error) { + ifaces, err := Interfaces() if err != nil { return nil, err } - ifaces := value iface, ok := ifaces[name] if !ok { return nil, ErrIfaceNotFound @@ -82,6 +86,21 @@ func ResolveInterface(name string) (*Interface, error) { return iface, nil } +func IsLocalIp(ip netip.Addr) (bool, error) { + ifaces, err := Interfaces() + if err != nil { + return false, err + } + for _, iface := range ifaces { + for _, addr := range iface.Addrs { + if addr.Contains(ip) { + return true, nil + } + } + } + return false, nil +} + func FlushCache() { interfaces.Reset() } diff --git a/component/loopback/detector.go b/component/loopback/detector.go index b07270ed0a..8ec96a9dd8 100644 --- a/component/loopback/detector.go +++ b/component/loopback/detector.go @@ -6,6 +6,7 @@ import ( "net/netip" "github.com/metacubex/mihomo/common/callback" + "github.com/metacubex/mihomo/component/iface" C "github.com/metacubex/mihomo/constant" "github.com/puzpuzpuz/xsync/v3" @@ -15,13 +16,13 @@ var ErrReject = errors.New("reject loopback connection") type Detector struct { connMap *xsync.MapOf[netip.AddrPort, struct{}] - packetConnMap *xsync.MapOf[netip.AddrPort, struct{}] + packetConnMap *xsync.MapOf[uint16, struct{}] } func NewDetector() *Detector { return &Detector{ connMap: xsync.NewMapOf[netip.AddrPort, struct{}](), - packetConnMap: xsync.NewMapOf[netip.AddrPort, struct{}](), + packetConnMap: xsync.NewMapOf[uint16, struct{}](), } } @@ -49,9 +50,10 @@ func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn { if !connAddr.IsValid() { return conn } - l.packetConnMap.Store(connAddr, struct{}{}) + port := connAddr.Port() + l.packetConnMap.Store(port, struct{}{}) return callback.NewCloseCallbackPacketConn(conn, func() { - l.packetConnMap.Delete(connAddr) + l.packetConnMap.Delete(port) }) } @@ -71,7 +73,16 @@ func (l *Detector) CheckPacketConn(metadata *C.Metadata) error { if !connAddr.IsValid() { return nil } - if _, ok := l.packetConnMap.Load(connAddr); ok { + + isLocalIp, err := iface.IsLocalIp(connAddr.Addr()) + if err != nil { + return err + } + if !isLocalIp && !connAddr.Addr().IsLoopback() { + return nil + } + + if _, ok := l.packetConnMap.Load(connAddr.Port()); ok { return fmt.Errorf("%w to: %s", ErrReject, metadata.RemoteAddress()) } return nil