Skip to content

Commit

Permalink
[ML] Move model test helper functions to base class (#1523)
Browse files Browse the repository at this point in the history
Move existing model test helper functions to base class and share functionality where possible.
  • Loading branch information
edsavage authored Oct 6, 2020
1 parent 3fc14ae commit 499338d
Show file tree
Hide file tree
Showing 8 changed files with 1,315 additions and 1,462 deletions.
112 changes: 42 additions & 70 deletions lib/model/unittest/CCountingModelTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,53 +29,28 @@ BOOST_AUTO_TEST_SUITE(CCountingModelTest)
using namespace ml;
using namespace model;

namespace {
std::size_t addPerson(const std::string& p,
const CModelFactory::TDataGathererPtr& gatherer,
CResourceMonitor& resourceMonitor) {
CDataGatherer::TStrCPtrVec person;
person.push_back(&p);
CEventData result;
gatherer->processFields(person, result, resourceMonitor);
return *result.personId();
}

void addArrival(CDataGatherer& gatherer,
CResourceMonitor& resourceMonitor,
core_t::TTime time,
const std::string& person) {
CDataGatherer::TStrCPtrVec fieldValues;
fieldValues.push_back(&person);

CEventData eventData;
eventData.time(time);
gatherer.addArrival(fieldValues, eventData, resourceMonitor);
}

SModelParams::TStrDetectionRulePr
makeScheduledEvent(const std::string& description, double start, double end) {
CRuleCondition conditionGte;
conditionGte.appliesTo(CRuleCondition::E_Time);
conditionGte.op(CRuleCondition::E_GTE);
conditionGte.value(start);
CRuleCondition conditionLt;
conditionLt.appliesTo(CRuleCondition::E_Time);
conditionLt.op(CRuleCondition::E_LT);
conditionLt.value(end);

CDetectionRule rule;
rule.action(CDetectionRule::E_SkipModelUpdate);
rule.addCondition(conditionGte);
rule.addCondition(conditionLt);

SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
return event;
}

const std::string EMPTY_STRING;
}

class CTestFixture : public CModelTestFixtureBase {};
class CTestFixture : public CModelTestFixtureBase {
protected:
SModelParams::TStrDetectionRulePr
makeScheduledEvent(const std::string& description, double start, double end) {
CRuleCondition conditionGte;
conditionGte.appliesTo(CRuleCondition::E_Time);
conditionGte.op(CRuleCondition::E_GTE);
conditionGte.value(start);
CRuleCondition conditionLt;
conditionLt.appliesTo(CRuleCondition::E_Time);
conditionLt.op(CRuleCondition::E_LT);
conditionLt.value(end);

CDetectionRule rule;
rule.action(CDetectionRule::E_SkipModelUpdate);
rule.addCondition(conditionGte);
rule.addCondition(conditionLt);

SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
return event;
}
};

BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
core_t::TTime startTime{100};
Expand All @@ -94,20 +69,20 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gathererNoGap(
factory.makeDataGatherer(gathererNoGapInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0), addPerson("p", gathererNoGap, m_ResourceMonitor));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererNoGap));
CModelFactory::SModelInitializationData modelNoGapInitData(gathererNoGap);
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());

// |2|2|0|0|1| -> 1.0 mean count
addArrival(*gathererNoGap, m_ResourceMonitor, 100, "p");
addArrival(*gathererNoGap, m_ResourceMonitor, 110, "p");
this->addArrival(*gathererNoGap, 100, "p");
this->addArrival(*gathererNoGap, 110, "p");
modelNoGap->sample(100, 200, m_ResourceMonitor);
addArrival(*gathererNoGap, m_ResourceMonitor, 250, "p");
addArrival(*gathererNoGap, m_ResourceMonitor, 280, "p");
this->addArrival(*gathererNoGap, 250, "p");
this->addArrival(*gathererNoGap, 280, "p");
modelNoGap->sample(200, 500, m_ResourceMonitor);
addArrival(*gathererNoGap, m_ResourceMonitor, 500, "p");
this->addArrival(*gathererNoGap, 500, "p");
modelNoGap->sample(500, 600, m_ResourceMonitor);

BOOST_REQUIRE_EQUAL(1.0, *modelNoGap->baselineBucketCount(0));
Expand All @@ -118,8 +93,7 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
CModelFactory::SGathererInitializationData gathererWithGapInitData(startTime);
CModelFactory::TDataGathererPtr gathererWithGap(
factory.makeDataGatherer(gathererWithGapInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0),
addPerson("p", gathererWithGap, m_ResourceMonitor));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererWithGap));
CModelFactory::SModelInitializationData modelWithGapInitData(gathererWithGap);
CAnomalyDetectorModel::TModelPtr modelHolderWithGap(
factory.makeModel(modelWithGapInitData));
Expand All @@ -128,15 +102,15 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {

// |2|2|0|0|1|
// |2|X|X|X|1| -> 1.5 mean count where X means skipped bucket
addArrival(*gathererWithGap, m_ResourceMonitor, 100, "p");
addArrival(*gathererWithGap, m_ResourceMonitor, 110, "p");
this->addArrival(*gathererWithGap, 100, "p");
this->addArrival(*gathererWithGap, 110, "p");
modelWithGap->sample(100, 200, m_ResourceMonitor);
addArrival(*gathererWithGap, m_ResourceMonitor, 250, "p");
addArrival(*gathererWithGap, m_ResourceMonitor, 280, "p");
this->addArrival(*gathererWithGap, 250, "p");
this->addArrival(*gathererWithGap, 280, "p");
modelWithGap->skipSampling(500);
modelWithGap->prune(maxAgeBuckets);
BOOST_REQUIRE_EQUAL(std::size_t(1), gathererWithGap->numberActivePeople());
addArrival(*gathererWithGap, m_ResourceMonitor, 500, "p");
this->addArrival(*gathererWithGap, 500, "p");
modelWithGap->sample(500, 600, m_ResourceMonitor);

BOOST_REQUIRE_EQUAL(1.5, *modelWithGap->baselineBucketCount(0));
Expand Down Expand Up @@ -166,7 +140,7 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
addArrival(*gatherer, m_ResourceMonitor, 200, "p");
this->addArrival(*gatherer, 200, "p");

CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
Expand Down Expand Up @@ -215,7 +189,7 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
addArrival(*gatherer, m_ResourceMonitor, 100, "p");
this->addArrival(*gatherer, 100, "p");

CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
Expand Down Expand Up @@ -259,8 +233,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {

CModelFactory::SGathererInitializationData gathererInitData(time);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0), addPerson("p1", gatherer, m_ResourceMonitor));
BOOST_REQUIRE_EQUAL(std::size_t(1), addPerson("p2", gatherer, m_ResourceMonitor));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", gatherer));
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", gatherer));
CModelFactory::SModelInitializationData modelInitData(gatherer);
CAnomalyDetectorModel::TModelPtr modelHolder(factory.makeModel(modelInitData));
CCountingModel* model{dynamic_cast<CCountingModel*>(modelHolder.get())};
Expand All @@ -275,9 +249,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
std::sort(offsets.begin(), offsets.end());
for (auto offset : offsets) {
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
addArrival(*gatherer, m_ResourceMonitor,
time + static_cast<core_t::TTime>(offset),
uniform01[0] < 0.5 ? "p1" : "p2");
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offset),
uniform01[0] < 0.5 ? "p1" : "p2");
}
model->sample(time, time + bucketLength, m_ResourceMonitor);
}
Expand All @@ -287,9 +260,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {

for (std::size_t i = 0u; i < offsets.size(); ++i) {
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
addArrival(*gatherer, m_ResourceMonitor,
time + static_cast<core_t::TTime>(offsets[i]),
uniform01[0] < 0.5 ? "p1" : "p2");
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offsets[i]),
uniform01[0] < 0.5 ? "p1" : "p2");
model->sampleBucketStatistics(time, time + bucketLength, m_ResourceMonitor);
BOOST_REQUIRE_EQUAL(static_cast<double>(i + 1) / 10.0,
interimBucketCorrector->completeness());
Expand Down
Loading

0 comments on commit 499338d

Please sign in to comment.