Skip to content

Commit

Permalink
inet: ScopedLwIPLock for better safety and added locks at necessary p…
Browse files Browse the repository at this point in the history
…laces (#28655)

* inet: scoped lwip locks for better safety and add at few more places

* Do not static assert if LWIP_TCPIP_CORE_LOCKING is disabled

* Add scope for locks

* move out the error variable definition to the top
  • Loading branch information
shubhamdp authored and pull[bot] committed Mar 6, 2024
1 parent 42849ec commit d630937
Showing 1 changed file with 87 additions and 79 deletions.
166 changes: 87 additions & 79 deletions src/inet/UDPEndPointImplLwIP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,26 @@ static_assert(LWIP_VERSION_MAJOR > 1, "CHIP requires LwIP 2.0 or later");
namespace chip {
namespace Inet {

namespace {
/**
* @brief
* RAII locking for LwIP core to simplify management of
* LOCK_TCPIP_CORE()/UNLOCK_TCPIP_CORE() calls.
*/
class ScopedLwIPLock
{
public:
ScopedLwIPLock() { LOCK_TCPIP_CORE(); }
~ScopedLwIPLock() { UNLOCK_TCPIP_CORE(); }
};
} // anonymous namespace

EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr;

CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port,
InterfaceId interfaceId)
{
// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Make sure we have the appropriate type of PCB.
CHIP_ERROR res = GetPCB(addressType);
Expand All @@ -90,9 +103,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddr
res = LwIPBindInterface(mUDP, interfaceId);
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();

return res;
}

Expand All @@ -101,7 +111,7 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf
// A lock is required because the LwIP thread may be referring to intf_filter,
// while this code running in the Inet application is potentially modifying it.
// NOTE: this only supports LwIP interfaces whose number is no bigger than 9.
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Make sure we have the appropriate type of PCB.
CHIP_ERROR err = GetPCB(addrType);
Expand All @@ -110,9 +120,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf
{
err = LwIPBindInterface(mUDP, intfId);
}

UNLOCK_TCPIP_CORE();

return err;
}

Expand All @@ -134,6 +141,8 @@ CHIP_ERROR UDPEndPointImplLwIP::LwIPBindInterface(struct udp_pcb * aUDP, Interfa

InterfaceId UDPEndPointImplLwIP::GetBoundInterface() const
{
ScopedLwIPLock lwipLock;

#if HAVE_LWIP_UDP_BIND_NETIF
return InterfaceId(netif_get_by_index(mUDP->netif_idx));
#else
Expand All @@ -148,14 +157,9 @@ uint16_t UDPEndPointImplLwIP::GetBoundPort() const

CHIP_ERROR UDPEndPointImplLwIP::ListenImpl()
{
// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

udp_recv(mUDP, LwIPReceiveUDPMessage, this);

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();

return CHIP_NO_ERROR;
}

Expand All @@ -174,53 +178,53 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY);
}

// Lock LwIP stack
LOCK_TCPIP_CORE();
CHIP_ERROR res = CHIP_NO_ERROR;
err_t lwipErr = ERR_VAL;

// Make sure we have the appropriate type of PCB based on the destination address.
CHIP_ERROR res = GetPCB(destAddr.Type());
if (res != CHIP_NO_ERROR)
// Adding a scope here to unlock the LwIP core when the lock is no longer required.
{
UNLOCK_TCPIP_CORE();
return res;
}
ScopedLwIPLock lwipLock;

// Send the message to the specified address/port.
// If an outbound interface has been specified, call a specific version of the UDP sendto()
// function that accepts the target interface.
// If a source address has been specified, temporarily override the local_ip of the PCB.
// This results in LwIP using the given address being as the source address for the generated
// packet, as if the PCB had been bound to that address.
err_t lwipErr = ERR_VAL;
const IPAddress & srcAddr = pktInfo->SrcAddress;
const uint16_t & destPort = pktInfo->DestPort;
const InterfaceId & intfId = pktInfo->Interface;
// Make sure we have the appropriate type of PCB based on the destination address.
res = GetPCB(destAddr.Type());
if (res != CHIP_NO_ERROR)
{
return res;
}

ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr();
ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr();
// Send the message to the specified address/port.
// If an outbound interface has been specified, call a specific version of the UDP sendto()
// function that accepts the target interface.
// If a source address has been specified, temporarily override the local_ip of the PCB.
// This results in LwIP using the given address being as the source address for the generated
// packet, as if the PCB had been bound to that address.
const IPAddress & srcAddr = pktInfo->SrcAddress;
const uint16_t & destPort = pktInfo->DestPort;
const InterfaceId & intfId = pktInfo->Interface;

ip_addr_t boundAddr;
ip_addr_copy(boundAddr, mUDP->local_ip);
ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr();
ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr();

if (!ip_addr_isany(&lwipSrcAddr))
{
ip_addr_copy(mUDP->local_ip, lwipSrcAddr);
}
ip_addr_t boundAddr;
ip_addr_copy(boundAddr, mUDP->local_ip);

if (intfId.IsPresent())
{
lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort,
intfId.GetPlatformInterface());
}
else
{
lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort);
}
if (!ip_addr_isany(&lwipSrcAddr))
{
ip_addr_copy(mUDP->local_ip, lwipSrcAddr);
}

ip_addr_copy(mUDP->local_ip, boundAddr);
if (intfId.IsPresent())
{
lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort,
intfId.GetPlatformInterface());
}
else
{
lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort);
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();
ip_addr_copy(mUDP->local_ip, boundAddr);
}

if (lwipErr != ERR_OK)
{
Expand All @@ -232,9 +236,7 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System

void UDPEndPointImplLwIP::CloseImpl()
{

// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Since UDP PCB is released synchronously here, but UDP endpoint itself might have to wait
// for destruction asynchronously, there could be more allocated UDP endpoints than UDP PCBs.
Expand All @@ -260,9 +262,6 @@ void UDPEndPointImplLwIP::CloseImpl()
}
}
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();
}

void UDPEndPointImplLwIP::Free()
Expand Down Expand Up @@ -473,19 +472,23 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv4JoinLeaveMulticastGroupImpl(InterfaceId aInt
const ip4_addr_t lIPv4Address = aAddress.ToIPv4();
err_t lStatus;

if (aInterfaceId.IsPresent())
{
ScopedLwIPLock lwipLock;

struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
if (aInterfaceId.IsPresent())
{

lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) //
: igmp_leavegroup_netif(lNetif, &lIPv4Address);
}
else
{
lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) //
: igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address);
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);

lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) //
: igmp_leavegroup_netif(lNetif, &lIPv4Address);
}
else
{
lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) //
: igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address);
}
}

if (lStatus == ERR_MEM)
Expand All @@ -504,17 +507,22 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv6JoinLeaveMulticastGroupImpl(InterfaceId aInt
#ifdef HAVE_IPV6_MULTICAST
const ip6_addr_t lIPv6Address = aAddress.ToIPv6();
err_t lStatus;
if (aInterfaceId.IsPresent())
{
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) //
: mld6_leavegroup_netif(lNetif, &lIPv6Address);
}
else

{
lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) //
: mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address);
ScopedLwIPLock lwipLock;

if (aInterfaceId.IsPresent())
{
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) //
: mld6_leavegroup_netif(lNetif, &lIPv6Address);
}
else
{
lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) //
: mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address);
}
}

if (lStatus == ERR_MEM)
Expand Down

0 comments on commit d630937

Please sign in to comment.