Skip to content

Commit

Permalink
CV macro with stratification column doesn't work (#213)
Browse files (browse the repository at this point in the history)
* Reduce number of hash bits in stratification column and add a unit test.

* Address PR comments.
  • Loading branch information
yaeldMS authored May 23, 2018
1 parent 2207a27 commit 73d894b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ public static string CreateStratificationColumn(IHost host, ref IDataView data,
new HashJoinTransform.Arguments
{
Column = new[] { new HashJoinTransform.Column { Name = stratCol, Source = stratificationColumn } },
Join = true
Join = true,
HashBits = 30
}, data);
}

Expand Down
68 changes: 68 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -330,5 +330,73 @@ public void TestCrossValidationMacro()
}
}
}

/// <summary>
/// Runs the cross-validation macro over breast-cancer.txt with a stratification
/// column ("Strat", loaded from source column 1) and checks that the overall
/// metrics contain exactly one row whose "AUC" value equals 0.99 to two decimals.
/// </summary>
[Fact]
public void TestCrossValidationMacroWithStratification()
{
    var dataPath = GetDataPath(@"breast-cancer.txt");
    using (var env = new TlcEnvironment())
    {
        // Inner graph that the cross-validator runs per fold:
        // no-op transform -> SDCA binary trainer -> model combiner.
        var foldGraph = env.CreateExperiment();

        var passThrough = new ML.Transforms.NoOperation();
        var passThroughOutput = foldGraph.Add(passThrough);

        var trainer = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier();
        trainer.TrainingData = passThroughOutput.OutputData;
        trainer.NumThreads = 1;
        var trainerOutput = foldGraph.Add(trainer);

        var combiner = new ML.Transforms.ManyHeterogeneousModelCombiner();
        combiner.TransformModels = new ArrayVar<ITransformModel>(passThroughOutput.Model);
        combiner.PredictorModel = trainerOutput.PredictorModel;
        var combinerOutput = foldGraph.Add(combiner);

        // Outer graph: text loader feeding the cross-validator.
        var experiment = env.CreateExperiment();

        var labelColumn = new ML.Data.TextLoaderColumn
        {
            Name = "Label",
            Source = new[] { new ML.Data.TextLoaderRange(0) }
        };
        var stratColumn = new ML.Data.TextLoaderColumn
        {
            Name = "Strat",
            Source = new[] { new ML.Data.TextLoaderRange(1) }
        };
        var featuresColumn = new ML.Data.TextLoaderColumn
        {
            Name = "Features",
            Source = new[] { new ML.Data.TextLoaderRange(2, 9) }
        };

        var loader = new ML.Data.TextLoader(dataPath);
        loader.Arguments.Column = new ML.Data.TextLoaderColumn[] { labelColumn, stratColumn, featuresColumn };
        var loaderOutput = experiment.Add(loader);

        var crossValidator = new ML.Models.CrossValidator();
        crossValidator.Data = loaderOutput.Data;
        crossValidator.Nodes = foldGraph;
        crossValidator.TransformModel = null;
        crossValidator.StratificationColumn = "Strat";
        // Wire the inner graph's entry and exit points into the macro.
        crossValidator.Inputs.Data = passThrough.Data;
        crossValidator.Outputs.Model = combinerOutput.PredictorModel;
        var crossValidatorOutput = experiment.Add(crossValidator);

        experiment.Compile();
        experiment.SetInput(loader.InputFile, new SimpleFileHandle(env, dataPath, false, false));
        experiment.Run();

        var overallMetrics = experiment.GetOutput(crossValidatorOutput.OverallMetrics[0]);

        int aucIndex;
        Assert.True(overallMetrics.Schema.TryGetColumnIndex("AUC", out aucIndex));

        using (var rows = overallMetrics.GetRowCursor(index => index == aucIndex))
        {
            var readAuc = rows.GetGetter<double>(aucIndex);

            // Exactly one metrics row is expected.
            Assert.True(rows.MoveNext());

            double auc = 0;
            readAuc(ref auc);
            Assert.Equal(0.99, auc, 2);

            Assert.False(rows.MoveNext());
        }
    }
}
}
}

0 comments on commit 73d894b

Please sign in to comment.