Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Distinguish missing and empty categorical values #1034

Merged
merged 6 commits into from
Mar 4, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Explicit missing value string
tveasey committed Mar 3, 2020
commit 04f9d17b19840434c334739d4b0cdde39effdef8
2 changes: 2 additions & 0 deletions include/api/CDataFrameAnalysisSpecification.h
Original file line number Diff line number Diff line change
@@ -66,6 +66,7 @@ class API_EXPORT CDataFrameAnalysisSpecification {
static const std::string THREADS;
static const std::string TEMPORARY_DIRECTORY;
static const std::string RESULTS_FIELD;
static const std::string MISSING_FIELD_VALUE;
static const std::string CATEGORICAL_FIELD_NAMES;
static const std::string DISK_USAGE_ALLOWED;
static const std::string ANALYSIS;
@@ -203,6 +204,7 @@ class API_EXPORT CDataFrameAnalysisSpecification {
std::string m_ResultsField;
std::string m_JobId;
std::string m_AnalysisName;
std::string m_MissingFieldValue;
TStrVec m_CategoricalFieldNames;
bool m_DiskUsageAllowed;
// TODO Sparse table support
11 changes: 11 additions & 0 deletions include/core/CDataFrame.h
Original file line number Diff line number Diff line change
@@ -238,6 +238,9 @@ class CORE_EXPORT CDataFrame final {
//! The maximum number of distinct categorical fields we can faithfully represent.
static const std::size_t MAX_CATEGORICAL_CARDINALITY;

//! The default value indicating that a value is missing.
static const std::string DEFAULT_MISSING_STRING;

public:
//! \param[in] inMainMemory True if the data frame is stored in main memory.
//! \param[in] numberColumns The number of columns in the data frame.
@@ -443,6 +446,9 @@ class CORE_EXPORT CDataFrame final {
//! Write the column names.
void columnNames(TStrVec columnNames);

//! Write the string which indicates that a value is missing.
void missingString(std::string missing);

//! Write for which columns an empty string implies the value is missing.
void emptyIsMissing(TBoolVec emptyIsMissing);

@@ -577,7 +583,12 @@ class CORE_EXPORT CDataFrame final {
//! A lookup for the integer value of categories.
TStrSizeUMapVec m_CategoricalColumnValueLookup;

//! The string which indicates that a category is missing.
std::string m_MissingString;

//! Indicator vector for treating empty strings as missing values.
// TODO Remove once Java passes the correct value for the missing target
// for classification.
TBoolVec m_EmptyIsMissing;

//! Indicator vector of the columns which contain categorical values.
28 changes: 17 additions & 11 deletions lib/api/CDataFrameAnalysisSpecification.cc
Original file line number Diff line number Diff line change
@@ -28,18 +28,19 @@ namespace ml {
namespace api {

// These must be consistent with Java names.
const std::string CDataFrameAnalysisSpecification::JOB_ID("job_id");
const std::string CDataFrameAnalysisSpecification::ROWS("rows");
const std::string CDataFrameAnalysisSpecification::COLS("cols");
const std::string CDataFrameAnalysisSpecification::MEMORY_LIMIT("memory_limit");
const std::string CDataFrameAnalysisSpecification::THREADS("threads");
const std::string CDataFrameAnalysisSpecification::TEMPORARY_DIRECTORY("temp_dir");
const std::string CDataFrameAnalysisSpecification::RESULTS_FIELD("results_field");
const std::string CDataFrameAnalysisSpecification::JOB_ID{"job_id"};
const std::string CDataFrameAnalysisSpecification::ROWS{"rows"};
const std::string CDataFrameAnalysisSpecification::COLS{"cols"};
const std::string CDataFrameAnalysisSpecification::MEMORY_LIMIT{"memory_limit"};
const std::string CDataFrameAnalysisSpecification::THREADS{"threads"};
const std::string CDataFrameAnalysisSpecification::TEMPORARY_DIRECTORY{"temp_dir"};
const std::string CDataFrameAnalysisSpecification::RESULTS_FIELD{"results_field"};
const std::string CDataFrameAnalysisSpecification::MISSING_FIELD_VALUE{"missing_field_value"};
const std::string CDataFrameAnalysisSpecification::CATEGORICAL_FIELD_NAMES{"categorical_fields"};
const std::string CDataFrameAnalysisSpecification::DISK_USAGE_ALLOWED("disk_usage_allowed");
const std::string CDataFrameAnalysisSpecification::ANALYSIS("analysis");
const std::string CDataFrameAnalysisSpecification::NAME("name");
const std::string CDataFrameAnalysisSpecification::PARAMETERS("parameters");
const std::string CDataFrameAnalysisSpecification::DISK_USAGE_ALLOWED{"disk_usage_allowed"};
const std::string CDataFrameAnalysisSpecification::ANALYSIS{"analysis"};
const std::string CDataFrameAnalysisSpecification::NAME{"name"};
const std::string CDataFrameAnalysisSpecification::PARAMETERS{"parameters"};

namespace {
using TBoolVec = std::vector<bool>;
@@ -75,6 +76,8 @@ const CDataFrameAnalysisConfigReader CONFIG_READER{[] {
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(CDataFrameAnalysisSpecification::RESULTS_FIELD,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(CDataFrameAnalysisSpecification::MISSING_FIELD_VALUE,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(CDataFrameAnalysisSpecification::CATEGORICAL_FIELD_NAMES,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(CDataFrameAnalysisSpecification::DISK_USAGE_ALLOWED,
@@ -131,6 +134,8 @@ CDataFrameAnalysisSpecification::CDataFrameAnalysisSpecification(
m_TemporaryDirectory = parameters[TEMPORARY_DIRECTORY].fallback(std::string{});
m_JobId = parameters[JOB_ID].fallback(std::string{});
m_ResultsField = parameters[RESULTS_FIELD].fallback(DEFAULT_RESULT_FIELD);
m_MissingFieldValue = parameters[MISSING_FIELD_VALUE].fallback(
core::CDataFrame::DEFAULT_MISSING_STRING);
m_CategoricalFieldNames = parameters[CATEGORICAL_FIELD_NAMES].fallback(TStrVec{});
m_DiskUsageAllowed = parameters[DISK_USAGE_ALLOWED].fallback(DEFAULT_DISK_USAGE_ALLOWED);

@@ -189,6 +194,7 @@ CDataFrameAnalysisSpecification::makeDataFrame() {
? core::makeMainStorageDataFrame(m_NumberColumns)
: core::makeDiskStorageDataFrame(m_TemporaryDirectory,
m_NumberColumns, m_NumberRows);
result.first->missingString(m_MissingFieldValue);
result.first->reserve(m_NumberThreads, m_NumberColumns + this->numberExtraColumns());

return result;
39 changes: 39 additions & 0 deletions lib/api/unittest/CDataFrameAnalyzerTrainingTest.cc
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
#include <maths/CBoostedTree.h>
#include <maths/CBoostedTreeFactory.h>
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CTools.h>

#include <api/CDataFrameAnalyzer.h>
@@ -369,6 +370,44 @@ void testOneRunOfBoostedTreeTrainingWithStateRecovery(F makeSpec, std::size_t it
}
}

BOOST_AUTO_TEST_CASE(testMissingString) {

// Test that the special missing value string is respected.

std::stringstream output;
auto outputWriterFactory = [&output]() {
return std::make_unique<core::CJsonOutputStreamWrapper>(output);
};

TStrVec fieldNames{"f1", "f2", "f3", "f4", "target", ".", "."};
TStrVec fieldValues{"a", "2.0", "3.0", "4.0", "5.0", "0", ""};

api::CDataFrameAnalyzer analyzer{
test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
test::CDataFrameAnalysisSpecificationFactory::regression(),
"target", 5, 5, 7000000, 0, 0, {"f1"}),
outputWriterFactory};

std::string a{"a"};
std::string b{"b"};
std::string missing{core::CDataFrame::DEFAULT_MISSING_STRING};
TBoolVec isMissing;
for (const auto& category : {a, missing, b, a, missing}) {
fieldValues[0] = category;
analyzer.handleRecord(fieldNames, fieldValues);
isMissing.push_back(category == missing);
}
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});

analyzer.dataFrame().readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
std::size_t i{0};
for (auto row = beginRows; row != endRows; ++row, ++i) {
BOOST_REQUIRE_EQUAL(isMissing[row->index()],
maths::CDataFrameUtils::isMissing((*row)[0]));
}
});
}

BOOST_AUTO_TEST_CASE(testRunBoostedTreeRegressionTraining) {

// Test the results the analyzer produces match running the regression directly.
18 changes: 16 additions & 2 deletions lib/core/CDataFrame.cc
Original file line number Diff line number Diff line change
@@ -124,8 +124,9 @@ CDataFrame::CDataFrame(bool inMainMemory,
const TWriteSliceToStoreFunc& writeSliceToStore)
: m_InMainMemory{inMainMemory}, m_NumberColumns{numberColumns},
m_RowCapacity{numberColumns}, m_SliceCapacityInRows{sliceCapacityInRows},
m_ReadAndWriteToStoreSyncStrategy{readAndWriteToStoreSyncStrategy}, m_WriteSliceToStore{writeSliceToStore},
m_ColumnNames(numberColumns), m_CategoricalColumnValues(numberColumns),
m_ReadAndWriteToStoreSyncStrategy{readAndWriteToStoreSyncStrategy},
m_WriteSliceToStore{writeSliceToStore}, m_ColumnNames(numberColumns),
m_CategoricalColumnValues(numberColumns), m_MissingString{DEFAULT_MISSING_STRING},
m_EmptyIsMissing(numberColumns, false),
m_ColumnIsCategorical(numberColumns, false) {
}
@@ -216,7 +217,13 @@ void CDataFrame::parseAndWriteRow(const TStrCRng& columnValues, const std::strin
auto stringToValue = [this](bool isCategorical, TStrSizeUMap& categoryLookup,
TStrVec& categories, bool emptyIsMissing,
const std::string& columnValue) {
if (columnValue == m_MissingString) {
++m_MissingValueCount;
return core::CFloatStorage{valueOfMissing()};
}

if (isCategorical) {
// TODO Remove when Java passes special missing value string.
if (columnValue.empty() && emptyIsMissing) {
return core::CFloatStorage{valueOfMissing()};
}
@@ -251,6 +258,7 @@ void CDataFrame::parseAndWriteRow(const TStrCRng& columnValues, const std::strin

double value;
if (columnValue.empty()) {
// TODO Remove when Java passes special missing value string.
++m_MissingValueCount;
return core::CFloatStorage{valueOfMissing()};
} else if (core::CStringUtils::stringToTypeSilent(columnValue, value) == false) {
@@ -300,6 +308,10 @@ void CDataFrame::columnNames(TStrVec columnNames) {
}
}

void CDataFrame::missingString(std::string missing) {
m_MissingString = std::move(missing);
}

void CDataFrame::emptyIsMissing(TBoolVec emptyIsMissing) {
if (emptyIsMissing.size() != m_NumberColumns) {
HANDLE_FATAL(<< "Internal error: expected '" << m_NumberColumns
@@ -374,6 +386,7 @@ std::size_t CDataFrame::memoryUsage() const {
std::size_t memory{CMemory::dynamicSize(m_ColumnNames)};
memory += CMemory::dynamicSize(m_CategoricalColumnValues);
memory += CMemory::dynamicSize(m_CategoricalColumnValueLookup);
memory += CMemory::dynamicSize(m_MissingString);
memory += CMemory::dynamicSize(m_EmptyIsMissing);
memory += CMemory::dynamicSize(m_ColumnIsCategorical);
memory += CMemory::dynamicSize(m_Slices);
@@ -630,6 +643,7 @@ bool CDataFrame::maskedRowsInSlice(ITR& maskedRow,

const std::size_t CDataFrame::MAX_CATEGORICAL_CARDINALITY{
1 << (std::numeric_limits<float>::digits)};
const std::string CDataFrame::DEFAULT_MISSING_STRING{"\0"};

CDataFrame::CDataFrameRowSliceWriter::CDataFrameRowSliceWriter(
std::size_t numberRows,
20 changes: 11 additions & 9 deletions lib/maths/CDataFrameUtils.cc
Original file line number Diff line number Diff line change
@@ -563,9 +563,11 @@ CDataFrameUtils::categoryFrequencies(std::size_t numberThreads,
[&](TDoubleVecVec& counts, TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
for (std::size_t i : columnMask) {
std::size_t category{static_cast<std::size_t>((*row)[i])};
counts[i].resize(std::max(counts[i].size(), category + 1), 0.0);
counts[i][category] += 1.0;
if (isMissing((*row)[i]) == false) {
std::size_t category{static_cast<std::size_t>((*row)[i])};
counts[i].resize(std::max(counts[i].size(), category + 1), 0.0);
counts[i][category] += 1.0;
}
}
}
},
@@ -588,12 +590,12 @@ CDataFrameUtils::categoryFrequencies(std::size_t numberThreads,
readCategoryCounts, &rowMask),
copyCategoryCounts, reduceCategoryCounts, result) == false) {
HANDLE_FATAL(<< "Internal error: failed to calculate category"
<< " frequencies. Please report this problem.");
<< " frequencies. Please report this problem.")
return result;
}
} catch (const std::exception& e) {
HANDLE_FATAL(<< "Internal error: '" << e.what() << "' exception calculating"
<< " category frequencies. Please report this problem.");
<< " category frequencies. Please report this problem.")
}

double Z{rowMask.manhattan()};
@@ -628,7 +630,7 @@ CDataFrameUtils::meanValueOfTargetForCategories(const CColumnValue& target,
[&](TMeanAccumulatorVecVec& means_, TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
for (std::size_t i : columnMask) {
if (isMissing(target(*row)) == false) {
if (isMissing((*row)[i]) == false && isMissing(target(*row)) == false) {
std::size_t category{static_cast<std::size_t>((*row)[i])};
means_[i].resize(std::max(means_[i].size(), category + 1));
means_[i][category].add(target(*row));
@@ -654,12 +656,12 @@ CDataFrameUtils::meanValueOfTargetForCategories(const CColumnValue& target,
if (doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), readColumnMeans, &rowMask),
copyColumnMeans, reduceColumnMeans, means) == false) {
HANDLE_FATAL(<< "Internal error: failed to calculate mean target values"
<< " for categories. Please report this problem.");
<< " for categories. Please report this problem.")
return result;
}
} catch (const std::exception& e) {
HANDLE_FATAL(<< "Internal error: '" << e.what() << "' exception calculating"
<< " mean target values for categories. Please report this problem.");
<< " mean target values for categories. Please report this problem.")
return result;
}
for (std::size_t i = 0; i < result.size(); ++i) {
@@ -760,7 +762,7 @@ CDataFrameUtils::maximumMinimumRecallDecisionThreshold(std::size_t numberThreads
TQuantileSketchVec classProbabilityClassOneQuantiles;
if (doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), readQuantiles, &rowMask),
copyQuantiles, reduceQuantiles, classProbabilityClassOneQuantiles) == false) {
HANDLE_FATAL(<< "Failed to compute category quantiles");
HANDLE_FATAL(<< "Failed to compute category quantiles")
return 0.5;
}