Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for UDCClientState #18449

Merged
merged 11 commits into from
May 17, 2022
5 changes: 3 additions & 2 deletions src/protocols/user_directed_commissioning/UDCClientState.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ class UDCClientState
size_t GetRotatingIdLength() const { return mRotatingIdLen; }
void SetRotatingId(const uint8_t * rotatingId, size_t rotatingIdLen)
{
memcpy(mRotatingId, rotatingId, rotatingIdLen);
mRotatingIdLen = rotatingIdLen;
size_t maxSize = ArraySize(mRotatingId);
mRotatingIdLen = (maxSize < rotatingIdLen) ? maxSize : rotatingIdLen;
memcpy(mRotatingId, rotatingId, mRotatingIdLen);
}

UDCClientProcessingState GetUDCClientProcessingState() const { return mUDCClientProcessingState; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ void UserDirectedCommissioningServer::OnCommissionableNodeFound(const Dnssd::Dis
client->SetLongDiscriminator(nodeData.commissionData.longDiscriminator);
client->SetVendorId(nodeData.commissionData.vendorId);
client->SetProductId(nodeData.commissionData.productId);
client->SetDeviceName(nodeData.commissionData.deviceName);
client->SetRotatingId(nodeData.commissionData.rotatingId, nodeData.commissionData.rotatingIdLen);

// Call the registered mUserConfirmationProvider, if any.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <nlunit-test.h>

#include <lib/core/CHIPSafeCasts.h>
#include <lib/dnssd/TxtFields.h>
#include <lib/support/BufferWriter.h>
#include <lib/support/CHIPMem.h>
#include <lib/support/CodeUtils.h>
Expand All @@ -15,7 +16,16 @@

using namespace chip;
using namespace chip::Protocols::UserDirectedCommissioning;
using namespace chip::Dnssd;
using namespace chip::Dnssd::Internal;

ByteSpan GetSpan(char * key)
{
size_t len = strlen(key);
// Stop the string from being null terminated to ensure the code makes no assumptions.
key[len] = '1';
return ByteSpan(Uint8::from_char(key), len);
}
class DLL_EXPORT TestCallback : public UserConfirmationProvider, public InstanceNameResolver
{
public:
Expand Down Expand Up @@ -267,6 +277,86 @@ void TestUDCClients(nlTestSuite * inSuite, void * inContext)
NL_TEST_ASSERT(inSuite, (expirationTime - System::Clock::Milliseconds64(1)) < state->GetExpirationTime());
}

void TestUDCClientState(nlTestSuite * inSuite, void * inContext)
{
UDCClients<3> mUdcClients;
const char * instanceName1 = "test1";
Inet::IPAddress address;
Inet::IPAddress::FromString("127.0.0.1", address);
uint16_t port = 333;
uint16_t longDiscriminator = 1234;
uint16_t vendorId = 1111;
uint16_t productId = 2222;
const char * deviceName = "test name";

// Rotating ID is given as up to 50 hex bytes
char rotatingIdString[chip::Dnssd::kMaxRotatingIdLen * 2 + 1];
uint8_t rotatingId[chip::Dnssd::kMaxRotatingIdLen];
size_t rotatingIdLen;
strcpy(rotatingIdString, "92873498273948734534");
GetRotatingDeviceId(GetSpan(rotatingIdString), rotatingId, &rotatingIdLen);

// create a Rotating ID longer than kMaxRotatingIdLen
char rotatingIdLongString[chip::Dnssd::kMaxRotatingIdLen * 4 + 1];
uint8_t rotatingIdLong[chip::Dnssd::kMaxRotatingIdLen * 2];
size_t rotatingIdLongLen;
strcpy(
rotatingIdLongString,
"123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890");

const ByteSpan & value = GetSpan(rotatingIdLongString);
rotatingIdLongLen = Encoding::HexToBytes(reinterpret_cast<const char *>(value.data()), value.size(), rotatingIdLong,
chip::Dnssd::kMaxRotatingIdLen * 2);

NL_TEST_ASSERT(inSuite, rotatingIdLongLen > chip::Dnssd::kMaxRotatingIdLen);

// test base case
UDCClientState * state = mUdcClients.FindUDCClientState(instanceName1);
NL_TEST_ASSERT(inSuite, state == nullptr);

// add a default state
NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == mUdcClients.CreateNewUDCClientState(instanceName1, &state));

// get the state
state = mUdcClients.FindUDCClientState(instanceName1);
NL_TEST_ASSERT(inSuite, nullptr != state);
NL_TEST_ASSERT(inSuite, strcmp(state->GetInstanceName(), instanceName1) == 0);

state->SetPeerAddress(chip::Transport::PeerAddress::UDP(address, port));
NL_TEST_ASSERT(inSuite, port == state->GetPeerAddress().GetPort());

state->SetDeviceName(deviceName);
NL_TEST_ASSERT(inSuite, strcmp(state->GetDeviceName(), deviceName) == 0);

state->SetLongDiscriminator(longDiscriminator);
NL_TEST_ASSERT(inSuite, longDiscriminator == state->GetLongDiscriminator());

state->SetVendorId(vendorId);
NL_TEST_ASSERT(inSuite, vendorId == state->GetVendorId());

state->SetProductId(productId);
NL_TEST_ASSERT(inSuite, productId == state->GetProductId());

state->SetRotatingId(rotatingId, rotatingIdLen);
NL_TEST_ASSERT(inSuite, rotatingIdLen == state->GetRotatingIdLength());

const uint8_t * testRotatingId = state->GetRotatingId();
for (size_t i = 0; i < rotatingIdLen; i++)
{
NL_TEST_ASSERT(inSuite, testRotatingId[i] == rotatingId[i]);
}

state->SetRotatingId(rotatingIdLong, rotatingIdLongLen);

NL_TEST_ASSERT(inSuite, chip::Dnssd::kMaxRotatingIdLen == state->GetRotatingIdLength());

const uint8_t * testRotatingIdLong = state->GetRotatingId();
for (size_t i = 0; i < chip::Dnssd::kMaxRotatingIdLen; i++)
{
NL_TEST_ASSERT(inSuite, testRotatingIdLong[i] == rotatingIdLong[i]);
}
}

// Test Suite

/**
Expand All @@ -280,6 +370,7 @@ static const nlTest sTests[] =
NL_TEST_DEF("TestUDCServerInstanceNameResolver", TestUDCServerInstanceNameResolver),
NL_TEST_DEF("TestUserDirectedCommissioningClientMessage", TestUserDirectedCommissioningClientMessage),
NL_TEST_DEF("TestUDCClients", TestUDCClients),
NL_TEST_DEF("TestUDCClientState", TestUDCClientState),

NL_TEST_SENTINEL()
};
Expand Down