Skip to content

Commit

Permalink
unauthenticated-session
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Sep 2, 2021
1 parent e071bd6 commit 030f20d
Show file tree
Hide file tree
Showing 29 changed files with 532 additions and 190 deletions.
3 changes: 1 addition & 2 deletions src/app/server/RendezvousServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params,
ReturnErrorOnFailure(mPairingSession.WaitForPairing(params.GetSetupPINCode(), pbkdf2IterCount, salt, keyID, this));
}

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mSessionMgr));

return CHIP_NO_ERROR;
}
Expand Down
17 changes: 13 additions & 4 deletions src/channel/ChannelContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,22 @@ void ChannelContext::EnterCasePairingState()
auto & prepare = GetPrepareVars();
prepare.mCasePairingSession = Platform::New<CASESession>();

ExchangeContext * ctxt =
mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

// TODO: currently only supports IP/UDP paring
Transport::PeerAddress addr;
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress);

auto session = mExchangeManager->GetSessionMgr()->CreateUnauthenticatedSession(addr);
if (!session.HasValue())
{
ExitCasePairingState();
ExitPreparingState();
EnterFailedState(CHIP_ERROR_NO_MEMORY);
return;
}

ExchangeContext * ctxt = mExchangeManager->NewContext(session.Value(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

Transport::FabricInfo * fabric = mFabricsTable->FindFabricWithIndex(mFabricIndex);
VerifyOrReturn(fabric != nullptr);
CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, fabric, prepare.mBuilder.GetPeerNodeId(),
Expand Down
13 changes: 9 additions & 4 deletions src/controller/CHIPDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,17 @@ CHIP_ERROR Device::WarmupCASESession()
VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR);

Messaging::ExchangeContext * exchange =
mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession);
// Create a UnauthenticatedSession for CASE pairing.
// Don't use mSecureSession here, because mSecureSession is the secure session.
Optional<SessionHandle> session = mSessionManager->CreateUnauthenticatedSession(mDeviceAddress);
if (!session.HasValue())
{
return CHIP_ERROR_NO_MEMORY;
}
Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(session.Value(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager()));
mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress);
ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager));

uint16_t keyID = 0;
ReturnErrorOnFailure(mIDAllocator->Allocate(keyID));
Expand Down
9 changes: 6 additions & 3 deletions src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam
Transport::PeerAddress peerAddress = Transport::PeerAddress::UDP(Inet::IPAddress::Any);

Messaging::ExchangeContext * exchangeCtxt = nullptr;
Optional<SessionHandle> session;

uint16_t keyID = 0;

Expand Down Expand Up @@ -855,9 +856,8 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam

mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle);

err = mPairingSession.MessageDispatch().Init(mTransportMgr);
err = mPairingSession.MessageDispatch().Init(mSessionMgr);
SuccessOrExit(err);
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

device->Init(GetControllerDeviceInitParams(), mListenPort, remoteDeviceId, peerAddress, fabric->GetFabricIndex());

Expand All @@ -883,7 +883,10 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam
}
}
#endif
exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession);
session = mSessionMgr->CreateUnauthenticatedSession(params.GetPeerAddress());
VerifyOrExit(session.HasValue(), CHIP_ERROR_NO_MEMORY);

exchangeCtxt = mExchangeMgr->NewContext(session.Value(), &mPairingSession);
VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL);

err = mIDAllocator.Allocate(keyID);
Expand Down
12 changes: 12 additions & 0 deletions src/lib/core/CHIPConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,18 @@
#define CHIP_CONFIG_ENABLE_IFJ_SERVICE_FABRIC_JOIN 0
#endif // CHIP_CONFIG_ENABLE_IFJ_SERVICE_FABRIC_JOIN

/**
* @def CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE
*
* @brief Define the size of the pool used for tracking CHIP unauthenticated
* states. The entries in the pool are automatically rotated by LRU. The size
* of the pool limits how many PASE and CASE pairing sessions can be processed
* simultaneously.
*/
#ifndef CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE
#define CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE 4
#endif // CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE

/**
* @def CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE
*
Expand Down
3 changes: 3 additions & 0 deletions src/lib/core/CHIPError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,9 @@ bool FormatCHIPError(char * buf, uint16_t bufSize, CHIP_ERROR err)
case CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED.AsInteger():
desc = "Duplicate message received";
break;
case CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW.AsInteger():
desc = "Message id out of window";
break;
}
#endif // !CHIP_CONFIG_SHORT_ERROR_STR

Expand Down
8 changes: 8 additions & 0 deletions src/lib/core/CHIPError.h
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,14 @@ using CHIP_ERROR = ::chip::ChipError;
*/
#define CHIP_ERROR_FABRIC_MISMATCH_ON_ICA CHIP_CORE_ERROR(0xc6)

/**
* @def CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW
*
* @brief
* The message id of the received message is out of receiving window
*/
#define CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW CHIP_CORE_ERROR(0xc7)

/**
* @}
*/
Expand Down
4 changes: 0 additions & 4 deletions src/lib/core/InPlace.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
*/
#pragma once

#include <core/CHIPCore.h>
#include <lib/core/InPlace.h>
#include <lib/support/Variant.h>

namespace chip {

/// InPlace is disambiguation tags that can be passed to the constructors to indicate that the contained object should be
Expand Down
17 changes: 15 additions & 2 deletions src/lib/support/Pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ template <class T, size_t N>
class BitMapObjectPool : public StaticAllocatorBitmap
{
public:
BitMapObjectPool() : StaticAllocatorBitmap(mMemory, mUsage, N, sizeof(T)) {}
BitMapObjectPool() : StaticAllocatorBitmap(mData.mMemory, mUsage, N, sizeof(T)) {}

static size_t Size() { return N; }

Expand All @@ -110,6 +110,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap
Deallocate(element);
}

template <typename... Args>
void ResetObject(T * element, Args &&... args)
{
element->~T();
new (element) T(std::forward<Args>(args)...);
}

/**
* @brief
* Run a functor for each active object in the pool
Expand Down Expand Up @@ -144,7 +151,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap
};

std::atomic<tBitChunkType> mUsage[(N + kBitChunkSize - 1) / kBitChunkSize];
alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)];
union Data
{
Data() {}
~Data() {}
alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)];
T mMemoryViewForDebug[N]; // Just for debugger
} mData;
};

} // namespace chip
9 changes: 7 additions & 2 deletions src/lib/support/ReferenceCountedHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@ class ReferenceCountedHandle
explicit ReferenceCountedHandle(Target & target) : mTarget(target) { mTarget.Retain(); }
~ReferenceCountedHandle() { mTarget.Release(); }

ReferenceCountedHandle(const ReferenceCountedHandle & that) = delete;
ReferenceCountedHandle(const ReferenceCountedHandle & that) : mTarget(that.mTarget) { mTarget.Retain(); }

ReferenceCountedHandle(ReferenceCountedHandle && that) : mTarget(that.mTarget) { mTarget.Retain(); }

ReferenceCountedHandle & operator=(const ReferenceCountedHandle & that) = delete;
ReferenceCountedHandle(ReferenceCountedHandle && that) = delete;
ReferenceCountedHandle & operator=(ReferenceCountedHandle && that) = delete;

bool operator==(const ReferenceCountedHandle & that) const { return &mTarget == &that.mTarget; }
bool operator!=(const ReferenceCountedHandle & that) const { return !(*this == that); }

Target * operator->() { return &mTarget; }
Target & Get() const { return mTarget; }

private:
Target & mTarget;
};
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ApplicationExchangeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CHIP_ERROR ApplicationExchangeDispatch::PrepareMessage(SessionHandle session, Pa
System::PacketBufferHandle && message,
EncryptedPacketBufferHandle & preparedMessage)
{
return mSessionMgr->BuildEncryptedMessagePayload(session, payloadHeader, std::move(message), preparedMessage);
return mSessionMgr->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage);
}

CHIP_ERROR ApplicationExchangeDispatch::SendPreparedMessage(SessionHandle session,
Expand Down
3 changes: 3 additions & 0 deletions src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;

// TODO: remove IsReliableTransmissionAllowed, this function should be provided over session.
virtual bool IsReliableTransmissionAllowed() const { return true; }

virtual bool IsEncryptionRequired() const { return true; }
};

Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ReliableMessageMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void ReliableMessageMgr::ClearRetransTable(RetransTableEntry & rEntry)
// Expire any virtual ticks that have expired so all wakeup sources reflect the current time
ExpireTicks();

rEntry.rc->ReleaseContext();
rEntry.rc->SetOccupied(false);
rEntry.rc->ReleaseContext();
rEntry.rc = nullptr;

// Clear all other fields
Expand Down
19 changes: 15 additions & 4 deletions src/messaging/tests/MessagingContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ CHIP_ERROR MessagingContext::Init(nlTestSuite * suite, TransportMgrBase * transp
ReturnErrorOnFailure(mExchangeManager.Init(&mSecureSessionMgr));
ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager));

ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(mPeer, GetDestinationNodeId(), &mPairingLocalToPeer,
SecureSession::SessionRole::kInitiator, mSrcFabricIndex));
ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(Optional<Transport::PeerAddress>::Value(mPeerAddress), GetDestinationNodeId(),
&mPairingLocalToPeer, SecureSession::SessionRole::kInitiator,
mSrcFabricIndex));

return mSecureSessionMgr.NewPairing(mPeer, GetSourceNodeId(), &mPairingPeerToLocal, SecureSession::SessionRole::kResponder,
mDestFabricIndex);
return mSecureSessionMgr.NewPairing(Optional<Transport::PeerAddress>::Value(mLocalAddress), GetSourceNodeId(),
&mPairingPeerToLocal, SecureSession::SessionRole::kResponder, mDestFabricIndex);
}

// Shutdown all layers, finalize operations
Expand All @@ -67,6 +68,16 @@ SessionHandle MessagingContext::GetSessionPeerToLocal()
return SessionHandle(GetSourceNodeId(), GetPeerKeyId(), GetLocalKeyId(), mDestFabricIndex);
}

Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToPeer(Messaging::ExchangeDelegate * delegate)
{
return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mPeerAddress).Value(), delegate);
}

Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToLocal(Messaging::ExchangeDelegate * delegate)
{
return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mLocalAddress).Value(), delegate);
}

Messaging::ExchangeContext * MessagingContext::NewExchangeToPeer(Messaging::ExchangeDelegate * delegate)
{
// TODO: temprary create a SessionHandle from node id, will be fix in PR 3602
Expand Down
9 changes: 7 additions & 2 deletions src/messaging/tests/MessagingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class MessagingContext
{
public:
MessagingContext() :
mInitialized(false), mPeer(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)),
mInitialized(false), mLocalAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)),
mPeerAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)),
mPairingPeerToLocal(GetLocalKeyId(), GetPeerKeyId()), mPairingLocalToPeer(GetPeerKeyId(), GetLocalKeyId())
{}
~MessagingContext() { VerifyOrDie(mInitialized == false); }
Expand Down Expand Up @@ -80,6 +81,9 @@ class MessagingContext
SessionHandle GetSessionLocalToPeer();
SessionHandle GetSessionPeerToLocal();

Messaging::ExchangeContext * NewUnauthenticatedExchangeToPeer(Messaging::ExchangeDelegate * delegate);
Messaging::ExchangeContext * NewUnauthenticatedExchangeToLocal(Messaging::ExchangeDelegate * delegate);

Messaging::ExchangeContext * NewExchangeToPeer(Messaging::ExchangeDelegate * delegate);
Messaging::ExchangeContext * NewExchangeToLocal(Messaging::ExchangeDelegate * delegate);

Expand All @@ -98,7 +102,8 @@ class MessagingContext
NodeId mDestinationNodeId = 111222333;
uint16_t mLocalKeyId = 1;
uint16_t mPeerKeyId = 2;
Optional<Transport::PeerAddress> mPeer;
Transport::PeerAddress mLocalAddress;
Transport::PeerAddress mPeerAddress;
SecurePairingUsingTestSecret mPairingPeerToLocal;
SecurePairingUsingTestSecret mPairingLocalToPeer;
Transport::FabricTable mFabrics;
Expand Down
Loading

0 comments on commit 030f20d

Please sign in to comment.