Skip to content

Commit

Permalink
[7.x][ML] Provide factory setup for creating models (elastic#1527) (elastic#1532)
Browse files Browse the repository at this point in the history

Move boilerplate code for creating models to a base class method.

This goes some way to reducing duplicated code and standardizing how models are created in the tests.

Backports elastic#1527
  • Loading branch information
edsavage authored Oct 13, 2020
1 parent 0cd1feb commit c6200ea
Show file tree
Hide file tree
Showing 6 changed files with 386 additions and 495 deletions.
86 changes: 37 additions & 49 deletions lib/model/unittest/CCountingModelTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using namespace model;

class CTestFixture : public CModelTestFixtureBase {
protected:
SModelParams::TStrDetectionRulePr
static SModelParams::TStrDetectionRulePr
makeScheduledEvent(const std::string& description, double start, double end) {
CRuleCondition conditionGte;
conditionGte.appliesTo(CRuleCondition::E_Time);
Expand All @@ -50,6 +50,13 @@ class CTestFixture : public CModelTestFixtureBase {
SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
return event;
}

// Convenience wrapper for the counting model tests: forwards to the base
// fixture's makeModelT, fixing the factory type to CCountingModelFactory and
// the model type to model_t::E_Counting.
//
// params    - model parameters forwarded unchanged to makeModelT.
// features  - features forwarded unchanged to makeModelT.
// startTime - start time forwarded unchanged to makeModelT.
//
// The fixture members m_Gatherer and m_Model are passed as the last two
// arguments; tests read them after this call, so makeModelT presumably
// populates them (its definition is not visible here - confirm in the base
// fixture).
void makeModel(const SModelParams& params,
const model_t::TFeatureVec& features,
core_t::TTime startTime) {
this->makeModelT<CCountingModelFactory>(
params, features, startTime, model_t::E_Counting, m_Gatherer, m_Model);
}
};

BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
Expand All @@ -66,14 +73,11 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {

// Model where gap is not skipped
{
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gathererNoGap(
factory.makeDataGatherer(gathererNoGapInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererNoGap));
CModelFactory::SModelInitializationData modelNoGapInitData(gathererNoGap);
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
CModelFactory::TDataGathererPtr gathererNoGap;
CModelFactory::TModelPtr modelNoGap;
this->makeModelT<CCountingModelFactory>(
params, features, startTime, model_t::E_Counting, gathererNoGap, modelNoGap);
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", gathererNoGap));

// |2|2|0|0|1| -> 1.0 mean count
this->addArrival(*gathererNoGap, 100, "p");
Expand All @@ -90,15 +94,12 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {

// Model where gap is skipped
{
CModelFactory::SGathererInitializationData gathererWithGapInitData(startTime);
CModelFactory::TDataGathererPtr gathererWithGap(
factory.makeDataGatherer(gathererWithGapInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererWithGap));
CModelFactory::SModelInitializationData modelWithGapInitData(gathererWithGap);
CAnomalyDetectorModel::TModelPtr modelHolderWithGap(
factory.makeModel(modelWithGapInitData));
CCountingModel* modelWithGap =
dynamic_cast<CCountingModel*>(modelHolderWithGap.get());
CModelFactory::TDataGathererPtr gathererWithGap;
CModelFactory::TModelPtr modelWithGap;
this->makeModelT<CCountingModelFactory>(params, features, startTime,
model_t::E_Counting,
gathererWithGap, modelWithGap);
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", gathererWithGap));

// |2|2|0|0|1|
// |2|X|X|X|1| -> 1.5 mean count where X means skipped bucket
Expand All @@ -109,7 +110,7 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
this->addArrival(*gathererWithGap, 280, "p");
modelWithGap->skipSampling(500);
modelWithGap->prune(maxAgeBuckets);
BOOST_REQUIRE_EQUAL(std::size_t(1), gathererWithGap->numberActivePeople());
BOOST_REQUIRE_EQUAL(1, gathererWithGap->numberActivePeople());
this->addArrival(*gathererWithGap, 500, "p");
modelWithGap->sample(500, 600, m_ResourceMonitor);

Expand Down Expand Up @@ -137,14 +138,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
factory.features(features);

{
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
this->addArrival(*gatherer, 200, "p");

CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
this->makeModel(params, features, startTime);
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get());
BOOST_TEST_REQUIRE(modelNoGap);
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", m_Gatherer));

SModelParams::TStrDetectionRulePrVec matchedEvents =
modelNoGap->checkScheduledEvents(50);
Expand Down Expand Up @@ -186,14 +183,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {

// Test sampleBucketStatistics
{
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
this->addArrival(*gatherer, 100, "p");

CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
CCountingModel* modelNoGap =
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
this->makeModel(params, features, startTime);
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get());
BOOST_TEST_REQUIRE(modelNoGap);
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", m_Gatherer));

// There are no events at this time
modelNoGap->sampleBucketStatistics(0, 100, m_ResourceMonitor);
Expand Down Expand Up @@ -226,18 +219,13 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {

SModelParams params(bucketLength);
params.s_DecayRate = 0.001;
auto interimBucketCorrector = std::make_shared<CInterimBucketCorrector>(bucketLength);
CCountingModelFactory factory(params, interimBucketCorrector);
model_t::TFeatureVec features{model_t::E_IndividualCountByBucketAndPerson};
factory.features(features);

CModelFactory::SGathererInitializationData gathererInitData(time);
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererInitData));
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", gatherer));
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", gatherer));
CModelFactory::SModelInitializationData modelInitData(gatherer);
CAnomalyDetectorModel::TModelPtr modelHolder(factory.makeModel(modelInitData));
CCountingModel* model{dynamic_cast<CCountingModel*>(modelHolder.get())};
this->makeModel(params, {model_t::E_IndividualCountByBucketAndPerson}, time);
CCountingModel* model = dynamic_cast<CCountingModel*>(m_Model.get());
BOOST_TEST_REQUIRE(model);

BOOST_REQUIRE_EQUAL(0, this->addPerson("p1", m_Gatherer));
BOOST_REQUIRE_EQUAL(1, this->addPerson("p2", m_Gatherer));

test::CRandomNumbers rng;

Expand All @@ -249,7 +237,7 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
std::sort(offsets.begin(), offsets.end());
for (auto offset : offsets) {
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offset),
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offset),
uniform01[0] < 0.5 ? "p1" : "p2");
}
model->sample(time, time + bucketLength, m_ResourceMonitor);
Expand All @@ -260,11 +248,11 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {

for (std::size_t i = 0u; i < offsets.size(); ++i) {
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offsets[i]),
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offsets[i]),
uniform01[0] < 0.5 ? "p1" : "p2");
model->sampleBucketStatistics(time, time + bucketLength, m_ResourceMonitor);
BOOST_REQUIRE_EQUAL(static_cast<double>(i + 1) / 10.0,
interimBucketCorrector->completeness());
m_InterimBucketCorrector->completeness());
}
}

Expand Down
Loading

0 comments on commit c6200ea

Please sign in to comment.