Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix invalid conversion between unsafe.Pointer and uintptr #6673

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pkg/agent/util/syscall/syscall_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,17 @@ type NetIOInterface interface {

type netIO struct {
syscallN func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno)
// It needs be declared as a variable and replaced during unit tests because the real getIPForwardTable function
// converts a Pointer to a uintptr as an argument of syscallN, while converting a uintptr back to a Pointer in the
// fake syscallN is not valid.
getIPForwardTable func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error)
}

func NewNetIO() NetIOInterface {
return &netIO{syscallN: syscall.SyscallN}
return &netIO{
syscallN: syscall.SyscallN,
getIPForwardTable: getIPForwardTable,
}
}

func (n *netIO) GetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) {
Expand Down Expand Up @@ -351,8 +358,8 @@ func (n *netIO) freeMibTable(table unsafe.Pointer) {
return
}

func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
r0, _, _ := n.syscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable)))
func getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
r0, _, _ := syscall.SyscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable)))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
Expand All @@ -362,21 +369,15 @@ func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTa
func (n *netIO) ListIPForwardRows(family uint16) ([]MibIPForwardRow, error) {
var table *MibIPForwardTable
err := n.getIPForwardTable(family, &table)
if table != nil {
defer n.freeMibTable(unsafe.Pointer(table))
}
if err != nil {
return nil, os.NewSyscallError("iphlpapi.GetIpForwardTable", err)
}
rows := make([]MibIPForwardRow, table.NumEntries, table.NumEntries)
defer n.freeMibTable(unsafe.Pointer(table))
Copy link
Contributor

@XinShuYang XinShuYang Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for removing “if table != nil” is that when err != nil, the memory release logic has already been correctly handled, so it won't cause a memory leak, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when err is not nil, the syscall failed and there is no memory to release, isn't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be, but it makes the original implementation a bit confusing. I don't understand why "if table != nil" check is needed before "if err != nil"(I recall that this part was copied from another project). @wenyingd Do you know where the source code is, and is there's any potential risk? Based on the function doc I found at https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getipforwardtable, the current code change is safe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/tailscale/winipcfg-go/blob/main/wt_mib_ipforward_row2.go#L58
This should be the place where we originally referred.
It should be good with the version in this change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is safe unless the syscall returns a table and an error at the same time, which doesn't seem normal in golang. And wireguard doesn't try to free the memory either when there is an error: https://github.com/WireGuard/wireguard-windows/blob/e70799b1440690e7d4140bffc7c73baf903c7b54/tunnel/winipcfg/winipcfg.go#L145


pFirstRow := uintptr(unsafe.Pointer(&table.Table[0]))
rowSize := unsafe.Sizeof(table.Table[0])

for i := uint32(0); i < table.NumEntries; i++ {
row := *(*MibIPForwardRow)(unsafe.Pointer(pFirstRow + rowSize*uintptr(i)))
rows[i] = row
}
// Copy the rows from the table into a new slice as the table's memory will be freed.
// Since MibIPForwardRow contains only value data (no references), the operation performs a deep copy.
rows := make([]MibIPForwardRow, 0, table.NumEntries)
rows = append(rows, unsafe.Slice(&table.Table[0], table.NumEntries)...)
return rows, nil
}

Expand Down
83 changes: 77 additions & 6 deletions pkg/agent/util/syscall/syscall_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (
"os"
"syscall"
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRawSockAddrTranslation(t *testing.T) {
Expand Down Expand Up @@ -202,16 +204,85 @@ func TestIPForwardEntryOperations(t *testing.T) {
}
}

func TestListIPForwardRows(t *testing.T) {
func TestListIPForwardRowsFailure(t *testing.T) {
testNetIO := &netIO{
getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
return syscall.Errno(22)
},
syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
assert.Fail(t, "freeMibTable shouldn't be called")
return
},
}
wantErr := os.NewSyscallError("iphlpapi.GetIpForwardTable", syscall.Errno(22))
testNetIO := NewTestNetIO(22)
// Skipping no error case because converting uintptr back to Pointer is not valid in general.
gotRow, gotErr := testNetIO.ListIPForwardRows(AF_INET)
assert.Nil(t, gotRow)
gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET)
assert.Nil(t, gotRows)
assert.Equal(t, wantErr, gotErr)
}

func NewTestNetIO(wantR1 uintptr) NetIOInterface {
func TestListIPForwardRowsSuccess(t *testing.T) {
row1 := MibIPForwardRow{
Luid: 10,
Index: 11,
DestinationPrefix: AddressPrefix{
Prefix: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{10, 10, 10, 0},
},
prefixLength: 24,
},
NextHop: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{11, 11, 11, 11},
},
}
row2 := MibIPForwardRow{
Luid: 20,
Index: 21,
DestinationPrefix: AddressPrefix{
Prefix: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{20, 20, 20, 0},
},
prefixLength: 24,
},
NextHop: RawSockAddrInet{
Family: AF_INET,
data: [26]byte{21, 21, 21, 21},
},
}
// The table contains two rows. Its memory address will be assigned to ipForwardTable when getIPForwardTable is called.
table := struct {
NumEntries uint32
Table [2]MibIPForwardRow
}{
NumEntries: 2,
Table: [2]MibIPForwardRow{row1, row2},
}
freeMibTableCalled := false
testNetIO := &netIO{
getIPForwardTable: func(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) {
*ipForwardTable = (*MibIPForwardTable)(unsafe.Pointer(&table))
return nil
},
syscallN: func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
freeMibTableCalled = true
// Reset the rows.
table.Table[0] = MibIPForwardRow{}
table.Table[1] = MibIPForwardRow{}
return
},
}
gotRows, gotErr := testNetIO.ListIPForwardRows(AF_INET)
require.NoError(t, gotErr)
assert.True(t, freeMibTableCalled)
// It verifies that the returned rows are independent copies, not referencing to the original table's memory, by
// asserting they retain the exact same content as the original table whose rows have been reset by freeMibTable.
expectedRows := []MibIPForwardRow{row1, row2}
assert.Equal(t, expectedRows, gotRows)
}

func NewTestNetIO(wantR1 uintptr) *netIO {
mockSyscallN := func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) {
return wantR1, 0, 0
}
Expand Down
Loading