Skip to content

Commit

Permalink
Minor code refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
beevik committed May 31, 2023
1 parent 036a5fe commit c527eb7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
24 changes: 12 additions & 12 deletions ntp.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func (m *msg) getLeap() LeapIndicator {
return LeapIndicator((m.LiVnMode >> 6) & 0x03)
}

// DialFunc is a function that connects to the remote network address and port
// from the local network address and port when using QueryWithOptions.
type DialFunc func(laddr string, lport int, raddr string, rport int) (net.Conn, error)

// QueryOptions contains the list of configurable options that may be used
// with the QueryWithOptions function.
type QueryOptions struct {
Expand All @@ -168,9 +172,7 @@ type QueryOptions struct {
LocalAddress string // IP address to use for the client address
Port int // Server port, defaults to 123
TTL int // IP TTL to use, defaults to system default

// Dial allows the user to override the default UDP dialer behavior when contacting the remote NTP server.
Dial func(localAddress string, localPort int, remoteAddress string, remotePort int) (net.Conn, error)
Dial DialFunc // Overrides the use of the default UDP dialer
}

// A Response contains time data, some of which is returned by the NTP server
Expand Down Expand Up @@ -426,25 +428,23 @@ func getTime(host string, opt QueryOptions) (*msg, ntpTime, error) {
}

// defaultDial provides a UDP dialer based on Go's built-in net stack.
func defaultDial(localAddress string, localPort int, remoteAddress string, remotePort int) (net.Conn, error) {
raddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(remoteAddress, strconv.Itoa(remotePort)))
func defaultDial(localAddr string, localPort int, remoteAddr string, remotePort int) (net.Conn, error) {
rhostport := net.JoinHostPort(remoteAddr, strconv.Itoa(remotePort))
raddr, err := net.ResolveUDPAddr("udp", rhostport)
if err != nil {
return nil, err
}

var laddr *net.UDPAddr
if localAddress != "" {
laddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(localAddress, strconv.Itoa(localPort)))
if localAddr != "" {
lhostport := net.JoinHostPort(localAddr, strconv.Itoa(localPort))
laddr, err = net.ResolveUDPAddr("udp", lhostport)
if err != nil {
return nil, err
}
}

con, err := net.DialUDP("udp", laddr, raddr)
if err != nil {
return nil, err
}
return con, err
return net.DialUDP("udp", laddr, raddr)
}

// parseTime parses the NTP packet along with the packet receive time to
Expand Down
37 changes: 21 additions & 16 deletions ntp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,24 +330,29 @@ func TestOfflineKissCode(t *testing.T) {
}

func TestOfflineCustomDialer(t *testing.T) {
ntpHost := "remote"
localHost := "local"
raddr := "remote"
laddr := "local"
dialerCalled := false

qo := QueryOptions{
LocalAddress: localHost,
Dial: func(la string, lp int, ra string, rp int) (net.Conn, error) {
assert.Equal(t, la, localHost)
assert.Equal(t, ra, ntpHost)
assert.Equal(t, rp, 123)
// Only expect to be called once:
assert.False(t, dialerCalled)

dialerCalled = true
return nil, errors.New("not dialing")
},
notDialingErr := errors.New("not dialing")

customDialer := func(la string, lp int, ra string, rp int) (net.Conn, error) {
assert.Equal(t, laddr, la)
assert.Equal(t, 0, lp)
assert.Equal(t, raddr, ra)
assert.Equal(t, 123, rp)
// Only expect to be called once:
assert.False(t, dialerCalled)

dialerCalled = true
return nil, notDialingErr
}
_, _ = QueryWithOptions(ntpHost, qo)

opt := QueryOptions{
LocalAddress: laddr,
Dial: customDialer,
}
r, err := QueryWithOptions(raddr, opt)
assert.Nil(t, r)
assert.Equal(t, notDialingErr, err)
assert.True(t, dialerCalled)
}

0 comments on commit c527eb7

Please sign in to comment.