Skip to content

Commit

Permalink
Add support for CASE session caching for session resume use cases
Browse files Browse the repository at this point in the history
- Add tests for the CASE session cache

- Update the CASESessionSerializable struct to have only necessary members
  and rename the struct and APIs appropriately
  • Loading branch information
nivi-apple committed Nov 19, 2021
1 parent 5bad4ed commit bbd68f5
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 59 deletions.
10 changes: 10 additions & 0 deletions src/lib/core/CHIPConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -2700,6 +2700,16 @@ extern const char CHIP_NON_PRODUCTION_MARKER[];
#define CHIP_CONFIG_MAX_SESSION_RELEASE_DELEGATES 2
#endif

/**
* @def CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE
*
* @brief
* Maximum number of CASE sessions that a device caches, that can be resumed
*/
#ifndef CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE
#define CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE 4
#endif

/**
* @}
*/
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ static_library("secure_channel") {
"CASEServer.h",
"CASESession.cpp",
"CASESession.h",
"CASESessionCache.cpp",
"CASESessionCache.h",
"PASESession.cpp",
"PASESession.h",
"RendezvousParameters.h",
Expand Down
74 changes: 35 additions & 39 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ constexpr uint8_t kTBEData3_Nonce[] =
constexpr size_t kTBEDataNonceLength = sizeof(kTBEData2_Nonce);
static_assert(sizeof(kTBEData2_Nonce) == sizeof(kTBEData3_Nonce), "TBEData2_Nonce and TBEData3_Nonce must be same size");

constexpr uint8_t kCASESessionVersion = 1;

enum
{
kTag_TBEData_SenderNOC = 1,
Expand Down Expand Up @@ -122,86 +120,84 @@ void CASESession::CloseExchange()
}
}

CHIP_ERROR CASESession::Serialize(CASESessionSerialized & output)
CHIP_ERROR CASESession::Serialize(CacheableCASEParameters & output)
{
uint16_t serializedLen = 0;
CASESessionSerializable serializable;
CASESessionCachable cachableSession;

VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(cachableSession)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(ToSerializable(serializable));
ReturnErrorOnFailure(ToCachable(cachableSession));

serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast<uint8_t *>(&serializable)),
static_cast<uint16_t>(sizeof(serializable)), Uint8::to_char(output.inner));
serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast<uint8_t *>(&cachableSession)),
static_cast<uint16_t>(sizeof(cachableSession)), Uint8::to_char(output.inner));
VerifyOrReturnError(serializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(serializedLen < sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT);
output.inner[serializedLen] = '\0';

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::Deserialize(CASESessionSerialized & input)
CHIP_ERROR CASESession::Deserialize(CacheableCASEParameters & input)
{
CASESessionSerializable serializable;
size_t maxlen = BASE64_ENCODED_LEN(sizeof(serializable));
CASESessionCachable cachableSession;
size_t maxlen = BASE64_ENCODED_LEN(sizeof(cachableSession));
size_t len = strnlen(Uint8::to_char(input.inner), maxlen);
uint16_t deserializedLen = 0;

VerifyOrReturnError(len < sizeof(CASESessionSerialized), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(len < sizeof(CacheableCASEParameters), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(CanCastTo<uint16_t>(len), CHIP_ERROR_INVALID_ARGUMENT);

memset(&serializable, 0, sizeof(serializable));
memset(&cachableSession, 0, sizeof(cachableSession));
deserializedLen =
Base64Decode(Uint8::to_const_char(input.inner), static_cast<uint16_t>(len), Uint8::to_uchar((uint8_t *) &serializable));
Base64Decode(Uint8::to_const_char(input.inner), static_cast<uint16_t>(len), Uint8::to_uchar((uint8_t *) &cachableSession));

VerifyOrReturnError(deserializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(deserializedLen <= sizeof(serializable), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(deserializedLen <= sizeof(cachableSession), CHIP_ERROR_INVALID_ARGUMENT);

ReturnErrorOnFailure(FromSerializable(serializable));
ReturnErrorOnFailure(FromCachable(cachableSession));

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::ToSerializable(CASESessionSerializable & serializable)
CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession)
{
const NodeId peerNodeId = GetPeerNodeId();
VerifyOrReturnError(CanCastTo<uint16_t>(mSharedSecret.Length()), CHIP_ERROR_INTERNAL);
VerifyOrReturnError(CanCastTo<uint16_t>(sizeof(mMessageDigest)), CHIP_ERROR_INTERNAL);
VerifyOrReturnError(CanCastTo<uint64_t>(peerNodeId), CHIP_ERROR_INTERNAL);

memset(&serializable, 0, sizeof(serializable));
serializable.mSharedSecretLen = LittleEndian::HostSwap16(static_cast<uint16_t>(mSharedSecret.Length()));
serializable.mMessageDigestLen = LittleEndian::HostSwap16(static_cast<uint16_t>(sizeof(mMessageDigest)));
serializable.mVersion = kCASESessionVersion;
serializable.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId);
serializable.mLocalSessionId = LittleEndian::HostSwap16(GetLocalSessionId());
serializable.mPeerSessionId = LittleEndian::HostSwap16(GetPeerSessionId());
memset(&cachableSession, 0, sizeof(cachableSession));
cachableSession.mMessageDigestLen = LittleEndian::HostSwap16(static_cast<uint16_t>(sizeof(mMessageDigest)));
cachableSession.mSharedSecretLen = LittleEndian::HostSwap16(static_cast<uint16_t>(mSharedSecret.Length()));
cachableSession.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId);
// TODO: Get the fabric index
cachableSession.mLocalFabricIndex = 0;
cachableSession.mSessionSetupTimeStamp = LittleEndian::HostSwap64(mSessionSetupTimeStamp);

memcpy(serializable.mResumptionId, mResumptionId, sizeof(mResumptionId));
memcpy(serializable.mSharedSecret, mSharedSecret, mSharedSecret.Length());
memcpy(serializable.mMessageDigest, mMessageDigest, sizeof(mMessageDigest));
memcpy(cachableSession.mResumptionId, mResumptionId, sizeof(mResumptionId));
memcpy(cachableSession.mSharedSecret, mSharedSecret, mSharedSecret.Length());
memcpy(cachableSession.mMessageDigest, mMessageDigest, sizeof(mMessageDigest));

return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::FromSerializable(const CASESessionSerializable & serializable)
CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession)
{
VerifyOrReturnError(serializable.mVersion == kCASESessionVersion, CHIP_ERROR_VERSION_MISMATCH);

uint16_t length = LittleEndian::HostSwap16(serializable.mSharedSecretLen);
uint16_t length = LittleEndian::HostSwap16(cachableSession.mSharedSecretLen);
ReturnErrorOnFailure(mSharedSecret.SetLength(static_cast<size_t>(length)));
memset(mSharedSecret, 0, sizeof(mSharedSecret.Capacity()));
memcpy(mSharedSecret, serializable.mSharedSecret, length);
memcpy(mSharedSecret, cachableSession.mSharedSecret, length);

length = LittleEndian::HostSwap16(serializable.mMessageDigestLen);
length = LittleEndian::HostSwap16(cachableSession.mMessageDigestLen);
VerifyOrReturnError(length <= sizeof(mMessageDigest), CHIP_ERROR_INVALID_ARGUMENT);
memcpy(mMessageDigest, serializable.mMessageDigest, length);
memcpy(mMessageDigest, cachableSession.mMessageDigest, length);

SetPeerNodeId(LittleEndian::HostSwap64(serializable.mPeerNodeId));
SetLocalSessionId(LittleEndian::HostSwap16(serializable.mLocalSessionId));
SetPeerSessionId(LittleEndian::HostSwap16(serializable.mPeerSessionId));
SetPeerNodeId(LittleEndian::HostSwap64(cachableSession.mPeerNodeId));
SetSessionTimeStamp(LittleEndian::HostSwap64(cachableSession.mSessionSetupTimeStamp));
// TODO: Set the fabric index correctly
mLocalFabricIndex = cachableSession.mLocalFabricIndex;

memcpy(mResumptionId, serializable.mResumptionId, sizeof(mResumptionId));
memcpy(mResumptionId, cachableSession.mResumptionId, sizeof(mResumptionId));

const ByteSpan * ipkListSpan = GetIPKList();
VerifyOrReturnError(ipkListSpan->size() == sizeof(mIPK), CHIP_ERROR_INVALID_ARGUMENT);
Expand Down
35 changes: 21 additions & 14 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,21 @@ constexpr size_t kCASEResumptionIDSize = 16;
#define CASE_EPHEMERAL_KEY 0xCA5EECD0
#endif

struct CASESessionSerialized;
struct CacheableCASEParameters;

struct CASESessionSerializable
struct CASESessionCachable
{
uint8_t mVersion;
uint16_t mSharedSecretLen;
uint8_t mSharedSecret[Crypto::kMax_ECDH_Secret_Length];
uint16_t mMessageDigestLen;
uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length];
FabricIndex mLocalFabricIndex;
NodeId mPeerNodeId;
uint16_t mLocalSessionId;
uint16_t mPeerSessionId;
// TODO: Use these once Auth Tags are supported in CHIP
uint32_t mAuthTag1;
uint32_t mAuthTag2;
uint8_t mResumptionId[kCASEResumptionIDSize];
uint64_t mSessionSetupTimeStamp;
};

class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public PairingSession
Expand Down Expand Up @@ -155,22 +157,22 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin
/**
* @brief Serialize the Pairing Session to a string.
**/
CHIP_ERROR Serialize(CASESessionSerialized & output);
CHIP_ERROR Serialize(CacheableCASEParameters & output);

/**
* @brief Deserialize the Pairing Session from the string.
**/
CHIP_ERROR Deserialize(CASESessionSerialized & input);
CHIP_ERROR Deserialize(CacheableCASEParameters & input);

/**
* @brief Serialize the CASESession to the given serializable data structure for secure pairing
* @brief Serialize the CASESession to the given cachableSession data structure for secure pairing
**/
CHIP_ERROR ToSerializable(CASESessionSerializable & output);
CHIP_ERROR ToCachable(CASESessionCachable & output);

/**
* @brief Reconstruct secure pairing class from the serializable data structure.
* @brief Reconstruct secure pairing class from the cachableSession data structure.
**/
CHIP_ERROR FromSerializable(const CASESessionSerializable & output);
CHIP_ERROR FromCachable(const CASESessionCachable & output);

SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; }

Expand Down Expand Up @@ -276,6 +278,9 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin

State mState;

uint8_t mLocalFabricIndex = 0;
uint64_t mSessionSetupTimeStamp = 0;

protected:
bool mCASESessionEstablished = false;

Expand All @@ -289,12 +294,14 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin
return ipkListSpan;
}
virtual size_t GetIPKListEntries() const { return 1; }

void SetSessionTimeStamp(uint64_t timestamp) { mSessionSetupTimeStamp = timestamp; }
};

typedef struct CASESessionSerialized
typedef struct CacheableCASEParameters
{
// Extra uint64_t to account for padding bytes (NULL termination, and some decoding overheads)
uint8_t inner[BASE64_ENCODED_LEN(sizeof(CASESessionSerializable) + sizeof(uint64_t))];
} CASESessionSerialized;
uint8_t inner[BASE64_ENCODED_LEN(sizeof(CASESessionCachable) + sizeof(uint64_t))];
} CacheableCASEParameters;

} // namespace chip
105 changes: 105 additions & 0 deletions src/protocols/secure_channel/CASESessionCache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <protocols/secure_channel/CASESessionCache.h>

namespace chip {

CASESessionCache::CASESessionCache() {}

CASESessionCache::~CASESessionCache()
{
mCachePool.ForEachActiveObject([&](auto * ec) {
mCachePool.ReleaseObject(ec);
return true;
});
}

CASESessionCachable * CASESessionCache::GetLRUSession()
{
uint64_t minTimeStamp = UINT64_MAX;
CASESessionCachable * lruSession = nullptr;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (minTimeStamp > ec->mSessionSetupTimeStamp)
{
minTimeStamp = ec->mSessionSetupTimeStamp;
lruSession = ec;
}
return true;
});
return lruSession;
}

CHIP_ERROR CASESessionCache::Add(CASESessionCachable & cachableSession)
{
// It's not an error if a device doesn't have cache for storing the sessions.
VerifyOrReturnError(mCachePool.Capacity() > 0, CHIP_NO_ERROR);

// If the cache is full, get the least recently used session index and release that.
if (mCachePool.Exhausted())
{
mCachePool.ReleaseObject(GetLRUSession());
}

mCachePool.CreateObject(cachableSession);
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESessionCache::Remove(ResumptionID resumptionID)
{
CHIP_ERROR err = CHIP_NO_ERROR;
CASESession session;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (resumptionID.data_equal(ResumptionID(ec->mResumptionId)))
{
mCachePool.ReleaseObject(ec);
}
return true;
});

return err;
}

CHIP_ERROR CASESessionCache::Get(ResumptionID resumptionID, CASESessionCachable & outSessionCachable)
{
CHIP_ERROR err = CHIP_NO_ERROR;
bool found = false;
mCachePool.ForEachActiveObject([&](auto * ec) {
if (resumptionID.data_equal(ResumptionID(ec->mResumptionId)))
{
found = true;
outSessionCachable = *ec;
return false;
}
return true;
});

if (!found)
{
err = CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND;
}

return err;
}

CHIP_ERROR CASESessionCache::Get(const PeerId & peer, CASESessionCachable & outSessionCachable)
{
// TODO: Implement this based on peer id
return CHIP_NO_ERROR;
}

} // namespace chip
44 changes: 44 additions & 0 deletions src/protocols/secure_channel/CASESessionCache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <lib/core/CHIPError.h>
#include <lib/core/PeerId.h>
#include <protocols/secure_channel/CASESession.h>

namespace chip {

using ResumptionID = FixedByteSpan<kCASEResumptionIDSize>;

class CASESessionCache
{
public:
CASESessionCache();
virtual ~CASESessionCache();

CHIP_ERROR Add(CASESessionCachable & cachableSession);
CHIP_ERROR Remove(ResumptionID resumptionID);
CHIP_ERROR Get(ResumptionID resumptionID, CASESessionCachable & outCachableSession);
CHIP_ERROR Get(const PeerId & peer, CASESessionCachable & outCachableSession);

private:
BitMapObjectPool<CASESessionCachable, CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE> mCachePool;
CASESessionCachable * GetLRUSession();
};

} // namespace chip
1 change: 1 addition & 0 deletions src/protocols/secure_channel/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ chip_test_suite("tests") {

test_sources = [
"TestCASESession.cpp",
"TestCASESessionCache.cpp",

# TODO - Fix Message Counter Sync to use group key
# "TestMessageCounterManager.cpp",
Expand Down
Loading

0 comments on commit bbd68f5

Please sign in to comment.