Skip to content

Commit

Permalink
Cover rollover case for AdvanceBy and add unit test to validate behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
mkardous-silabs committed Apr 26, 2024
1 parent de85890 commit 1951690
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 23 deletions.
72 changes: 49 additions & 23 deletions src/lib/support/PersistedCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,43 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
return MonotonicallyIncreasingCounter<T>::Init(startValue);
}

/**
* @brief Increment the counter by N and write to persisted storage if we've completed the current epoch.
*
* @param value value of N
*
* @return Any error returned by a write to persisted storage.
*/
CHIP_ERROR AdvanceBy(T value) override
{
VerifyOrReturnError(mStorage != nullptr, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(mKey.IsInitialized(), CHIP_ERROR_INCORRECT_STATE);

// If value is 0, we do not need to do anything
VerifyOrReturnError(value > 0, CHIP_NO_ERROR);

// We should update the persisted epoch value if :
// 1- Sum of the current counter and value is greater or equal to the mNextEpoch.
// This is the standard operating case.
// 2- Increasing the current counter by value would cause a roll over. This would cause the current value to be < to the
// mNextEpoch so we force an update.
bool shouldDoEpochUpdate = ((MonotonicallyIncreasingCounter<T>::GetValue() + value) >= mNextEpoch) ||
(MonotonicallyIncreasingCounter<T>::GetValue() > INT_MAX - value);

ReturnErrorOnFailure(MonotonicallyIncreasingCounter<T>::AdvanceBy(value));

return VerifyAndPersistNextEpochStart(MonotonicallyIncreasingCounter<T>::GetValue());
// Since AdvanceBy allows the counter to be increase by an arbitrary value, it is possible that the new counter value is
// greater than mNextEpoch + mEpoch. As such, we want to the next Epoch value to be calculate from the new current value.
if (shouldDoEpochUpdate)
{
PersistAndVerifyNextEpochStart(MonotonicallyIncreasingCounter<T>::GetValue());
}

return CHIP_NO_ERROR;
}

/**
* @brief
* Increment the counter and write to persisted storage if we've completed
* the current epoch.
* @brief Increment the counter and write to persisted storage if we've completed the current epoch.
*
* @return Any error returned by a write to persisted storage.
*/
Expand All @@ -134,29 +157,31 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>

ReturnErrorOnFailure(MonotonicallyIncreasingCounter<T>::Advance());

return VerifyAndPersistNextEpochStart(mNextEpoch);
if (MonotonicallyIncreasingCounter<T>::GetValue() >= mNextEpoch)
{
ReturnErrorOnFailure(PersistAndVerifyNextEpochStart(mNextEpoch));
}

return CHIP_NO_ERROR;
}

private:
CHIP_ERROR VerifyAndPersistNextEpochStart(T refEpoch)
CHIP_ERROR PersistAndVerifyNextEpochStart(T refEpoch)
{
// Value advanced past the previously persisted "start point".
// Ensure that a new starting point is persisted.
ReturnErrorOnFailure(PersistNextEpochStart(static_cast<T>(refEpoch + mEpoch)));

// Advancing the epoch should have ensured that the current value is valid
VerifyOrReturnError(static_cast<T>(MonotonicallyIncreasingCounter<T>::GetValue() + mEpoch) == mNextEpoch,
CHIP_ERROR_INTERNAL);

// Previous check did not take into consideration that the counter value can be equal to the max counter value or
// rollover.
// TODO(#33175): PersistedCounter allows rollover so this check is incorrect. We need a Counter class that adequatly
// manages rollover behavior for counters that cannot rollover.
// VerifyOrReturnError(MonotonicallyIncreasingCounter<T>::GetValue() < mNextEpoch, CHIP_ERROR_INTERNAL);

if (MonotonicallyIncreasingCounter<T>::GetValue() >= mNextEpoch)
{
// Value advanced past the previously persisted "start point".
// Ensure that a new starting point is persisted.
ReturnErrorOnFailure(PersistNextEpochStart(static_cast<T>(refEpoch + mEpoch)));

// Advancing the epoch should have ensured that the current value is valid
VerifyOrReturnError(static_cast<T>(MonotonicallyIncreasingCounter<T>::GetValue() + mEpoch) == mNextEpoch,
CHIP_ERROR_INTERNAL);

// Previous check did not take into consideration that the counter value can be equal to the max counter value or
// rollover.
// TODO(#33175): PersistedCounter allows rollover so this check is incorrect. We need a Counter class that adequatly
// manages rollover behavior for counters that cannot rollover.
// VerifyOrReturnError(MonotonicallyIncreasingCounter<T>::GetValue() < mNextEpoch, CHIP_ERROR_INTERNAL);
}
return CHIP_NO_ERROR;
}

Expand All @@ -168,7 +193,8 @@ class PersistedCounter : public MonotonicallyIncreasingCounter<T>
*
* @return Any error returned by a write to persistent storage.
*/
CHIP_ERROR PersistNextEpochStart(T aStartValue)
CHIP_ERROR
PersistNextEpochStart(T aStartValue)
{
mNextEpoch = aStartValue;
#if CHIP_CONFIG_PERSISTED_COUNTER_DEBUG_LOGGING
Expand Down
48 changes: 48 additions & 0 deletions src/lib/support/tests/TestPersistedCounter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,52 @@ TEST(TestPersistedCounter, TestAdvanceByMaxCounterValue)
EXPECT_EQ(counter.GetValue(), 0ULL);
}

TEST(TestPersistedCounter, TestAdvanceByRollover)
{
chip::TestPersistentStorageDelegate storage;
chip::PersistedCounter<uint64_t> counter;

uint64_t epoch = UINT64_MAX / 4;
uint64_t currentEpoch = epoch;
uint64_t current = 0;
uint64_t storedValue = 0;
uint16_t size = sizeof(storedValue);

EXPECT_EQ(counter.Init(&storage, chip::DefaultStorageKeyAllocator::IMEventNumber(), epoch), CHIP_NO_ERROR);
EXPECT_EQ(counter.GetValue(), current);

// Check new Epoch value was persisted
EXPECT_EQ(storage.SyncGetKeyValue(chip::DefaultStorageKeyAllocator::IMEventNumber().KeyName(), &storedValue, size),
CHIP_NO_ERROR);
EXPECT_EQ(sizeof(storedValue), size);
storedValue = Encoding::LittleEndian::HostSwap<uint64_t>(storedValue);
EXPECT_EQ(currentEpoch, storedValue);

// Increase counter to update persisted value
current += (currentEpoch + 100);
EXPECT_EQ(counter.AdvanceBy(currentEpoch + 100), CHIP_NO_ERROR);
EXPECT_EQ(counter.GetValue(), current);

// Check new Epoch value was persisted
currentEpoch = (currentEpoch * 2 + 100);
EXPECT_EQ(storage.SyncGetKeyValue(chip::DefaultStorageKeyAllocator::IMEventNumber().KeyName(), &storedValue, size),
CHIP_NO_ERROR);
EXPECT_EQ(sizeof(storedValue), size);
storedValue = Encoding::LittleEndian::HostSwap<uint64_t>(storedValue);
EXPECT_EQ(currentEpoch, storedValue);

// Force roll over
current += (3 * epoch);
EXPECT_EQ(counter.AdvanceBy((3 * epoch)), CHIP_NO_ERROR);
EXPECT_EQ(counter.GetValue(), current);

// Check new Epoch value was persisted
currentEpoch = current + epoch;
EXPECT_EQ(storage.SyncGetKeyValue(chip::DefaultStorageKeyAllocator::IMEventNumber().KeyName(), &storedValue, size),
CHIP_NO_ERROR);
EXPECT_EQ(sizeof(storedValue), size);
storedValue = Encoding::LittleEndian::HostSwap<uint64_t>(storedValue);
EXPECT_EQ(currentEpoch, storedValue);
}

} // namespace

0 comments on commit 1951690

Please sign in to comment.