-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Provide factory setup for creating models #1527
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -31,7 +31,7 @@ using namespace model; | |||||||||
|
||||||||||
class CTestFixture : public CModelTestFixtureBase { | ||||||||||
protected: | ||||||||||
SModelParams::TStrDetectionRulePr | ||||||||||
static SModelParams::TStrDetectionRulePr | ||||||||||
makeScheduledEvent(const std::string& description, double start, double end) { | ||||||||||
CRuleCondition conditionGte; | ||||||||||
conditionGte.appliesTo(CRuleCondition::E_Time); | ||||||||||
|
@@ -50,6 +50,13 @@ class CTestFixture : public CModelTestFixtureBase { | |||||||||
SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule); | ||||||||||
return event; | ||||||||||
} | ||||||||||
|
||||||||||
// Fixture convenience helper for counting-model tests: forwards params,
// features and startTime to makeModelT<CCountingModelFactory>, requesting the
// model_t::E_Counting category and passing the fixture's m_Gatherer and
// m_Model as the gatherer/model arguments.
// NOTE(review): makeModelT, m_Gatherer and m_Model are not declared in this
// view; presumably they come from CModelTestFixtureBase - confirm.
void makeModel(const SModelParams& params, | ||||||||||
const model_t::TFeatureVec& features, | ||||||||||
core_t::TTime startTime) { | ||||||||||
this->makeModelT<CCountingModelFactory>( | ||||||||||
params, features, startTime, model_t::E_Counting, m_Gatherer, m_Model); | ||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) { | ||||||||||
|
@@ -66,14 +73,11 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) { | |||||||||
|
||||||||||
// Model where gap is not skipped | ||||||||||
{ | ||||||||||
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime); | ||||||||||
CModelFactory::TDataGathererPtr gathererNoGap( | ||||||||||
factory.makeDataGatherer(gathererNoGapInitData)); | ||||||||||
CModelFactory::TDataGathererPtr gathererNoGap; | ||||||||||
CModelFactory::TModelPtr modelNoGap; | ||||||||||
this->makeModelT<CCountingModelFactory>( | ||||||||||
params, features, startTime, model_t::E_Counting, gathererNoGap, modelNoGap); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererNoGap)); | ||||||||||
CModelFactory::SModelInitializationData modelNoGapInitData(gathererNoGap); | ||||||||||
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData)); | ||||||||||
CCountingModel* modelNoGap = | ||||||||||
dynamic_cast<CCountingModel*>(modelHolderNoGap.get()); | ||||||||||
|
||||||||||
// |2|2|0|0|1| -> 1.0 mean count | ||||||||||
this->addArrival(*gathererNoGap, 100, "p"); | ||||||||||
|
@@ -90,15 +94,12 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) { | |||||||||
|
||||||||||
// Model where gap is skipped | ||||||||||
{ | ||||||||||
CModelFactory::SGathererInitializationData gathererWithGapInitData(startTime); | ||||||||||
CModelFactory::TDataGathererPtr gathererWithGap( | ||||||||||
factory.makeDataGatherer(gathererWithGapInitData)); | ||||||||||
CModelFactory::TDataGathererPtr gathererWithGap; | ||||||||||
CModelFactory::TModelPtr modelWithGap; | ||||||||||
this->makeModelT<CCountingModelFactory>(params, features, startTime, | ||||||||||
model_t::E_Counting, | ||||||||||
gathererWithGap, modelWithGap); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererWithGap)); | ||||||||||
CModelFactory::SModelInitializationData modelWithGapInitData(gathererWithGap); | ||||||||||
CAnomalyDetectorModel::TModelPtr modelHolderWithGap( | ||||||||||
factory.makeModel(modelWithGapInitData)); | ||||||||||
CCountingModel* modelWithGap = | ||||||||||
dynamic_cast<CCountingModel*>(modelHolderWithGap.get()); | ||||||||||
|
||||||||||
// |2|2|0|0|1| | ||||||||||
// |2|X|X|X|1| -> 1.5 mean count where X means skipped bucket | ||||||||||
|
@@ -137,14 +138,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) { | |||||||||
factory.features(features); | ||||||||||
|
||||||||||
{ | ||||||||||
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime); | ||||||||||
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData)); | ||||||||||
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer); | ||||||||||
this->addArrival(*gatherer, 200, "p"); | ||||||||||
|
||||||||||
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData)); | ||||||||||
CCountingModel* modelNoGap = | ||||||||||
dynamic_cast<CCountingModel*>(modelHolderNoGap.get()); | ||||||||||
this->makeModel(params, features, startTime); | ||||||||||
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get()); | ||||||||||
BOOST_TEST_REQUIRE(modelNoGap); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", m_Gatherer)); | ||||||||||
|
||||||||||
SModelParams::TStrDetectionRulePrVec matchedEvents = | ||||||||||
modelNoGap->checkScheduledEvents(50); | ||||||||||
|
@@ -186,14 +183,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) { | |||||||||
|
||||||||||
// Test sampleBucketStatistics | ||||||||||
{ | ||||||||||
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime); | ||||||||||
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData)); | ||||||||||
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer); | ||||||||||
this->addArrival(*gatherer, 100, "p"); | ||||||||||
|
||||||||||
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData)); | ||||||||||
CCountingModel* modelNoGap = | ||||||||||
dynamic_cast<CCountingModel*>(modelHolderNoGap.get()); | ||||||||||
this->makeModel(params, features, startTime); | ||||||||||
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get()); | ||||||||||
BOOST_TEST_REQUIRE(modelNoGap); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", m_Gatherer)); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
// There are no events at this time | ||||||||||
modelNoGap->sampleBucketStatistics(0, 100, m_ResourceMonitor); | ||||||||||
|
@@ -226,18 +219,13 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) { | |||||||||
|
||||||||||
SModelParams params(bucketLength); | ||||||||||
params.s_DecayRate = 0.001; | ||||||||||
auto interimBucketCorrector = std::make_shared<CInterimBucketCorrector>(bucketLength); | ||||||||||
CCountingModelFactory factory(params, interimBucketCorrector); | ||||||||||
model_t::TFeatureVec features{model_t::E_IndividualCountByBucketAndPerson}; | ||||||||||
factory.features(features); | ||||||||||
|
||||||||||
CModelFactory::SGathererInitializationData gathererInitData(time); | ||||||||||
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererInitData)); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", gatherer)); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", gatherer)); | ||||||||||
CModelFactory::SModelInitializationData modelInitData(gatherer); | ||||||||||
CAnomalyDetectorModel::TModelPtr modelHolder(factory.makeModel(modelInitData)); | ||||||||||
CCountingModel* model{dynamic_cast<CCountingModel*>(modelHolder.get())}; | ||||||||||
this->makeModel(params, {model_t::E_IndividualCountByBucketAndPerson}, time); | ||||||||||
CCountingModel* model = dynamic_cast<CCountingModel*>(m_Model.get()); | ||||||||||
BOOST_TEST_REQUIRE(model); | ||||||||||
|
||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", m_Gatherer)); | ||||||||||
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", m_Gatherer)); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
test::CRandomNumbers rng; | ||||||||||
|
||||||||||
|
@@ -249,7 +237,7 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) { | |||||||||
std::sort(offsets.begin(), offsets.end()); | ||||||||||
for (auto offset : offsets) { | ||||||||||
rng.generateUniformSamples(0.0, 1.0, 1, uniform01); | ||||||||||
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offset), | ||||||||||
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offset), | ||||||||||
uniform01[0] < 0.5 ? "p1" : "p2"); | ||||||||||
} | ||||||||||
model->sample(time, time + bucketLength, m_ResourceMonitor); | ||||||||||
|
@@ -260,11 +248,11 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) { | |||||||||
|
||||||||||
for (std::size_t i = 0u; i < offsets.size(); ++i) { | ||||||||||
rng.generateUniformSamples(0.0, 1.0, 1, uniform01); | ||||||||||
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offsets[i]), | ||||||||||
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offsets[i]), | ||||||||||
uniform01[0] < 0.5 ? "p1" : "p2"); | ||||||||||
model->sampleBucketStatistics(time, time + bucketLength, m_ResourceMonitor); | ||||||||||
BOOST_REQUIRE_EQUAL(static_cast<double>(i + 1) / 10.0, | ||||||||||
interimBucketCorrector->completeness()); | ||||||||||
m_InterimBucketCorrector->completeness()); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -187,53 +187,21 @@ class CTestFixture : public CModelTestFixtureBase { | |||||
core_t::TTime startTime, | ||||||
std::size_t numberPeople, | ||||||
const std::string& summaryCountField = EMPTY_STRING) { | ||||||
this->makeModel(params, features, startTime, numberPeople, m_Gatherer, | ||||||
m_Model, summaryCountField); | ||||||
} | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, m_Gatherer, | ||||||
m_Model, {}, summaryCountField); | ||||||
|
||||||
void makeModel(const SModelParams& params, | ||||||
const model_t::TFeatureVec& features, | ||||||
core_t::TTime startTime, | ||||||
std::size_t numberPeople, | ||||||
CModelFactory::TDataGathererPtr& gatherer, | ||||||
CModelFactory::TModelPtr& model, | ||||||
const std::string& summaryCountField = EMPTY_STRING) { | ||||||
if (m_InterimBucketCorrector == nullptr) { | ||||||
m_InterimBucketCorrector = | ||||||
std::make_shared<CInterimBucketCorrector>(params.s_BucketLength); | ||||||
} | ||||||
if (m_Factory == nullptr) { | ||||||
m_Factory.reset(new CEventRateModelFactory( | ||||||
params, m_InterimBucketCorrector, | ||||||
summaryCountField.empty() ? model_t::E_None : model_t::E_Manual, | ||||||
summaryCountField)); | ||||||
m_Factory->features(features); | ||||||
} | ||||||
gatherer.reset(m_Factory->makeDataGatherer({startTime})); | ||||||
model.reset(m_Factory->makeModel({gatherer})); | ||||||
BOOST_TEST_REQUIRE(model); | ||||||
BOOST_REQUIRE_EQUAL(model_t::E_EventRateOnline, model->category()); | ||||||
BOOST_REQUIRE_EQUAL(params.s_BucketLength, model->bucketLength()); | ||||||
for (std::size_t i = 0u; i < numberPeople; ++i) { | ||||||
BOOST_REQUIRE_EQUAL( | ||||||
std::size_t(i), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
this->addPerson("p" + core::CStringUtils::typeToString(i + 1), gatherer)); | ||||||
this->addPerson("p" + core::CStringUtils::typeToString(i + 1), m_Gatherer)); | ||||||
} | ||||||
} | ||||||
|
||||||
protected: | ||||||
using TInterimBucketCorrectorPtr = std::shared_ptr<CInterimBucketCorrector>; | ||||||
using TEventRateModelFactoryPtr = std::shared_ptr<CEventRateModelFactory>; | ||||||
|
||||||
using TDoubleSizeStrTr = core::CTriple<double, std::size_t, std::string>; | ||||||
using TMinAccumulator = maths::CBasicStatistics::COrderStatisticsHeap<TDoubleSizeStrTr>; | ||||||
using TMinAccumulatorVec = std::vector<TMinAccumulator>; | ||||||
|
||||||
protected: | ||||||
TInterimBucketCorrectorPtr m_InterimBucketCorrector; | ||||||
TEventRateModelFactoryPtr m_Factory; | ||||||
ml::model::CModelFactory::TDataGathererPtr m_Gatherer; | ||||||
ml::model::CModelFactory::TModelPtr m_Model; | ||||||
}; | ||||||
|
||||||
BOOST_FIXTURE_TEST_CASE(testCountSample, CTestFixture) { | ||||||
|
@@ -986,12 +954,15 @@ BOOST_FIXTURE_TEST_CASE(testPrune, CTestFixture) { | |||||
features.push_back(model_t::E_IndividualTotalBucketCountByPerson); | ||||||
CModelFactory::TDataGathererPtr gatherer; | ||||||
CModelFactory::TModelPtr model_; | ||||||
this->makeModel(params, features, startTime, 0, gatherer, model_); | ||||||
this->makeModelT<CEventRateModelFactory>( | ||||||
params, features, startTime, model_t::E_EventRateOnline, gatherer, model_); | ||||||
CEventRateModel* model = dynamic_cast<CEventRateModel*>(model_.get()); | ||||||
BOOST_TEST_REQUIRE(model); | ||||||
CModelFactory::TDataGathererPtr expectedGatherer; | ||||||
CModelFactory::TModelPtr expectedModel_; | ||||||
this->makeModel(params, features, startTime, 0, expectedGatherer, expectedModel_); | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, | ||||||
expectedGatherer, expectedModel_); | ||||||
CEventRateModel* expectedModel = | ||||||
dynamic_cast<CEventRateModel*>(expectedModel_.get()); | ||||||
BOOST_TEST_REQUIRE(expectedModel); | ||||||
|
@@ -2026,8 +1997,15 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) { | |||||
model_t::TFeatureVec features{feature}; | ||||||
CModelFactory::TDataGathererPtr gathererNoGap; | ||||||
CModelFactory::TModelPtr modelNoGap_; | ||||||
this->makeModel(params, features, startTime, 2, gathererNoGap, modelNoGap_); | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, | ||||||
gathererNoGap, modelNoGap_); | ||||||
CEventRateModel* modelNoGap = dynamic_cast<CEventRateModel*>(modelNoGap_.get()); | ||||||
for (std::size_t i = 0u; i < 2; ++i) { | ||||||
BOOST_REQUIRE_EQUAL( | ||||||
std::size_t(i), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
this->addPerson("p" + core::CStringUtils::typeToString(i + 1), gathererNoGap)); | ||||||
} | ||||||
|
||||||
// p1: |1|1|1| | ||||||
// p2: |1|0|0| | ||||||
|
@@ -2041,8 +2019,15 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) { | |||||
|
||||||
CModelFactory::TDataGathererPtr gathererWithGap; | ||||||
CModelFactory::TModelPtr modelWithGap_; | ||||||
this->makeModel(params, features, startTime, 2, gathererWithGap, modelWithGap_); | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, | ||||||
gathererWithGap, modelWithGap_); | ||||||
CEventRateModel* modelWithGap = dynamic_cast<CEventRateModel*>(modelWithGap_.get()); | ||||||
for (std::size_t i = 0u; i < 2; ++i) { | ||||||
BOOST_REQUIRE_EQUAL( | ||||||
std::size_t(i), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
this->addPerson("p" + core::CStringUtils::typeToString(i + 1), gathererWithGap)); | ||||||
} | ||||||
|
||||||
// p1: |1|1|0|0|0|0|0|0|0|0|1|1| | ||||||
// p1: |1|X|X|X|X|X|X|X|X|X|1|1| -> equal to |1|1|1| | ||||||
|
@@ -2108,8 +2093,9 @@ BOOST_FIXTURE_TEST_CASE(testExplicitNulls, CTestFixture) { | |||||
model_t::TFeatureVec features{feature}; | ||||||
CModelFactory::TDataGathererPtr gathererSkipGap; | ||||||
CModelFactory::TModelPtr modelSkipGap_; | ||||||
this->makeModel(params, features, startTime, 0, gathererSkipGap, | ||||||
modelSkipGap_, summaryCountField); | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, gathererSkipGap, | ||||||
modelSkipGap_, {}, summaryCountField); | ||||||
CEventRateModel* modelSkipGap = dynamic_cast<CEventRateModel*>(modelSkipGap_.get()); | ||||||
|
||||||
// The idea here is to compare a model that has a gap skipped against a model | ||||||
|
@@ -2137,8 +2123,9 @@ BOOST_FIXTURE_TEST_CASE(testExplicitNulls, CTestFixture) { | |||||
|
||||||
CModelFactory::TDataGathererPtr gathererExNull; | ||||||
CModelFactory::TModelPtr modelExNullGap_; | ||||||
this->makeModel(params, features, startTime, 0, gathererExNull, | ||||||
modelExNullGap_, summaryCountField); | ||||||
this->makeModelT<CEventRateModelFactory>(params, features, startTime, | ||||||
model_t::E_EventRateOnline, gathererExNull, | ||||||
modelExNullGap_, {}, summaryCountField); | ||||||
CEventRateModel* modelExNullGap = | ||||||
dynamic_cast<CEventRateModel*>(modelExNullGap_.get()); | ||||||
|
||||||
|
@@ -2413,14 +2400,17 @@ BOOST_FIXTURE_TEST_CASE(testSummaryCountZeroRecordsAreIgnored, CTestFixture) { | |||||
|
||||||
CModelFactory::TDataGathererPtr gathererWithZeros; | ||||||
CModelFactory::TModelPtr modelWithZerosPtr; | ||||||
this->makeModel(params, {model_t::E_IndividualCountByBucketAndPerson}, startTime, | ||||||
0, gathererWithZeros, modelWithZerosPtr, summaryCountField); | ||||||
this->makeModelT<CEventRateModelFactory>( | ||||||
params, {model_t::E_IndividualCountByBucketAndPerson}, startTime, | ||||||
model_t::E_EventRateOnline, gathererWithZeros, modelWithZerosPtr, {}, | ||||||
summaryCountField); | ||||||
CEventRateModel& modelWithZeros = static_cast<CEventRateModel&>(*modelWithZerosPtr); | ||||||
|
||||||
CModelFactory::TDataGathererPtr gathererNoZeros; | ||||||
CModelFactory::TModelPtr modelNoZerosPtr; | ||||||
this->makeModel(params, {model_t::E_IndividualCountByBucketAndPerson}, startTime, | ||||||
0, gathererNoZeros, modelNoZerosPtr, summaryCountField); | ||||||
this->makeModelT<CEventRateModelFactory>( | ||||||
params, {model_t::E_IndividualCountByBucketAndPerson}, startTime, | ||||||
model_t::E_EventRateOnline, gathererNoZeros, modelNoZerosPtr, {}, summaryCountField); | ||||||
CEventRateModel& modelNoZeros = static_cast<CEventRateModel&>(*modelNoZerosPtr); | ||||||
|
||||||
// The idea here is to compare a model that has records with summary count of zero | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no need to cast expected types with Boost.Test. The places where we do this in existing code are a hangover from CppUnit.