diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index ce273c5da54b..bae3b106e1d2 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -163,25 +163,41 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in if !destinationOverridden { writer = &buf.SequentialWriter{Writer: conn} } else { - sockopt := &internet.SocketConfig{ - Tproxy: internet.SocketConfig_TProxy, - } + var addr *net.UDPAddr + var mark int if dest.Address.Family().IsIP() { - sockopt.BindAddress = dest.Address.IP() - sockopt.BindPort = uint32(dest.Port) + addr = &net.UDPAddr{ + IP: dest.Address.IP(), + Port: int(dest.Port), + } } if d.sockopt != nil { - sockopt.Mark = d.sockopt.Mark + mark = int(d.sockopt.Mark) } - to := net.DestinationFromAddr(conn.RemoteAddr()) - tConn, err := internet.DialSystem(ctx, to, sockopt) + pConn, err := FakeUDP(addr, mark) if err != nil { return err } - writer = NewPacketWriter(tConn, &dest, ctx, &to, sockopt) + back := net.DestinationFromAddr(conn.RemoteAddr()) + writer = NewPacketWriter(pConn, &dest, mark, &back) defer writer.(*PacketWriter).Close() /* + sockopt := &internet.SocketConfig{ + Tproxy: internet.SocketConfig_TProxy, + } + if dest.Address.Family().IsIP() { + sockopt.BindAddress = dest.Address.IP() + sockopt.BindPort = uint32(dest.Port) + } + if d.sockopt != nil { + sockopt.Mark = d.sockopt.Mark + } + tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt) + if err != nil { + return err + } defer tConn.Close() + writer = &buf.SequentialWriter{Writer: tConn} tReader := buf.NewPacketReader(tConn) requestCount++ @@ -220,24 +236,25 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in return nil } -func NewPacketWriter(conn net.Conn, d *net.Destination, ctx context.Context, to *net.Destination, sockopt *internet.SocketConfig) buf.Writer { +func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.Destination) buf.Writer { writer := &PacketWriter{ - conn: conn, - conns: make(map[net.Destination]net.Conn), - ctx: ctx, - to: to, - sockopt: sockopt, + conn: conn, + conns: make(map[net.Destination]net.PacketConn), + mark: mark, + back: &net.UDPAddr{ + IP: back.Address.IP(), + Port: int(back.Port), + }, } writer.conns[*d] = conn return writer } type PacketWriter struct { - conn net.Conn - conns map[net.Destination]net.Conn - ctx context.Context - to *net.Destination - sockopt *internet.SocketConfig + conn net.PacketConn + conns map[net.Destination]net.PacketConn + mark int + back *net.UDPAddr } func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { @@ -251,23 +268,34 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if b.UDP != nil && b.UDP.Address.Family().IsIP() { conn := w.conns[*b.UDP] if conn == nil { - w.sockopt.BindAddress = b.UDP.Address.IP() - w.sockopt.BindPort = uint32(b.UDP.Port) - conn, _ = internet.DialSystem(w.ctx, *w.to, w.sockopt) - if conn == nil { + conn, err = FakeUDP( + &net.UDPAddr{ + IP: b.UDP.Address.IP(), + Port: int(b.UDP.Port), + }, + w.mark, + ) + if err != nil { b.Release() - continue + buf.ReleaseMulti(mb) + return err } w.conns[*b.UDP] = conn } - _, err = conn.Write(b.Bytes()) + _, err = conn.WriteTo(b.Bytes(), w.back) + if err != nil { + conn.Close() + w.conns[*b.UDP] = nil + newError(err).WriteToLog() + } + b.Release() } else { - _, err = w.conn.Write(b.Bytes()) - } - b.Release() - if err != nil { - buf.ReleaseMulti(mb) - return err + _, err = w.conn.WriteTo(b.Bytes(), w.back) + b.Release() + if err != nil { + buf.ReleaseMulti(mb) + return err + } } } return nil diff --git a/proxy/dokodemo/fakeudp_linux.go b/proxy/dokodemo/fakeudp_linux.go new file mode 100644 index 000000000000..f793661bdf8f --- /dev/null +++ b/proxy/dokodemo/fakeudp_linux.go @@ -0,0 +1,85 @@ +// +build linux + +package dokodemo + +import ( + "fmt" + "net" + "os" + "strconv" + "syscall" +) + +func FakeUDP(addr *net.UDPAddr, mark int) (net.PacketConn, error) { + + if addr == nil { + addr = &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + } + } + + localSocketAddress, af, err := udpAddrToSocketAddr(addr) + if err != nil { + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("build local socket address: %s", err)} + } + + fileDescriptor, err := syscall.Socket(af, syscall.SOCK_DGRAM, 0) + if err != nil { + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("socket open: %s", err)} + } + + if mark != 0 { + if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_MARK, mark); err != nil { + syscall.Close(fileDescriptor) + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("set socket option: SO_MARK: %s", err)} + } + } + + if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + syscall.Close(fileDescriptor) + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %s", err)} + } + + if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil { + syscall.Close(fileDescriptor) + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)} + } + + if err = syscall.Bind(fileDescriptor, localSocketAddress); err != nil { + syscall.Close(fileDescriptor) + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("socket bind: %s", err)} + } + + fdFile := os.NewFile(uintptr(fileDescriptor), fmt.Sprintf("net-udp-fake-%s", addr.String())) + defer fdFile.Close() + + packetConn, err := net.FilePacketConn(fdFile) + if err != nil { + syscall.Close(fileDescriptor) + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("convert file descriptor to connection: %s", err)} + } + + return packetConn, nil +} + +func udpAddrToSocketAddr(addr *net.UDPAddr) (syscall.Sockaddr, int, error) { + switch { + case addr.IP.To4() != nil: + ip := [4]byte{} + copy(ip[:], addr.IP.To4()) + + return &syscall.SockaddrInet4{Addr: ip, Port: addr.Port}, syscall.AF_INET, nil + + default: + ip := [16]byte{} + copy(ip[:], addr.IP.To16()) + + zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) + if err != nil { + return nil, 0, err + } + + return &syscall.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, syscall.AF_INET6, nil + } +} diff --git a/proxy/dokodemo/fakeudp_other.go b/proxy/dokodemo/fakeudp_other.go new file mode 100644 index 000000000000..abad344e0946 --- /dev/null +++ b/proxy/dokodemo/fakeudp_other.go @@ -0,0 +1,12 @@ +// +build !linux + +package dokodemo + +import ( + "fmt" + "net" +) + +func FakeUDP(addr *net.UDPAddr, mark int) (net.PacketConn, error) { + return nil, &net.OpError{Op: "fake", Err: fmt.Errorf("!linux")} +}