Skip to content

Commit

Permalink
fix: implement DialContext for dialers
Browse files Browse the repository at this point in the history
  • Loading branch information
mzz2017 committed Sep 18, 2024
1 parent ff62fae commit 1d3f903
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 21 deletions.
6 changes: 2 additions & 4 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -372,8 +371,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions common/netutils/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
}

// Dial and write.
cd := &netproxy.ContextDialerConverter{Dialer: d}
c, err := cd.DialContext(ctx, network, dns.String())
c, err := d.DialContext(ctx, network, dns.String())
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions component/outbound/dialer/connectivity_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,11 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
if method == "" {
method = http.MethodGet
}
cd := &netproxy.ContextDialerConverter{Dialer: d.Dialer}
cli := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip".
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port()))
conn, err := d.Dialer.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port()))
if err != nil {
return nil, err
}
Expand Down
7 changes: 2 additions & 5 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,16 +560,13 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte

ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
bestContextDialer := netproxy.ContextDialerConverter{
Dialer: dialArgument.bestDialer,
}

switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP:
// Get udp endpoint.

// TODO: connection pool.
conn, err = bestContextDialer.DialContext(
conn, err = dialArgument.bestDialer.DialContext(
ctxDial,
common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp),
dialArgument.bestTarget.String(),
Expand Down Expand Up @@ -634,7 +631,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine.

conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}
Expand Down
5 changes: 1 addition & 4 deletions control/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
cd := netproxy.ContextDialerConverter{
Dialer: d,
}
return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
return d.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
}

type WriteCloser interface {
Expand Down
5 changes: 1 addition & 4 deletions control/udp_endpoint_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,9 @@ begin:
if err != nil {
return nil, false, err
}
cd := netproxy.ContextDialerConverter{
Dialer: dialOption.Dialer,
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target)
udpConn, err := dialOption.Dialer.DialContext(ctx, dialOption.Network, dialOption.Target)
if err != nil {
return nil, true, err
}
Expand Down

0 comments on commit 1d3f903

Please sign in to comment.