Skip to content

Commit

Permalink
StochasticGradientDescentClassificationTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
abgoswam committed Jan 13, 2019
1 parent 60eb9d5 commit 2a77e5e
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 38 deletions.
40 changes: 22 additions & 18 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"lc",
"sasdca")]

[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Arguments),
[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
StochasticGradientDescentClassificationTrainer.UserNameValue,
StochasticGradientDescentClassificationTrainer.LoadNameValue,
Expand Down Expand Up @@ -69,6 +69,13 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn,
{
}

private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
SchemaShape.Column weightColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn),
labelColumn, weightColumn)
{
}

private protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
Expand Down Expand Up @@ -1595,7 +1602,7 @@ public sealed class StochasticGradientDescentClassificationTrainer :
internal const string UserNameValue = "Hogwild SGD (binary)";
internal const string ShortName = "HogwildSGD";

public sealed class Arguments : LearnerInputBaseWithWeight
public sealed class Options : LearnerInputBaseWithWeight
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportClassificationLossFactory LossFunction = new LogLossFactory();
Expand Down Expand Up @@ -1670,7 +1677,7 @@ internal static class Defaults
}

private readonly IClassificationLoss _loss;
private readonly Arguments _args;
private readonly Options _args;

protected override bool ShuffleData => _args.Shuffle;

Expand All @@ -1689,29 +1696,24 @@ internal static class Defaults
/// <param name="initLearningRate">The initial learning rate used by SGD.</param>
/// <param name="l2Weight">The L2 regularizer constant.</param>
/// <param name="loss">The loss function to use.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
int maxIterations = Arguments.Defaults.MaxIterations,
double initLearningRate = Arguments.Defaults.InitLearningRate,
float l2Weight = Arguments.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null,
Action<Arguments> advancedSettings = null)
int maxIterations = Options.Defaults.MaxIterations,
double initLearningRate = Options.Defaults.InitLearningRate,
float l2Weight = Options.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

_args = new Arguments();
_args = new Options();
_args.MaxIterations = maxIterations;
_args.InitLearningRate = initLearningRate;
_args.L2Weight = l2Weight;

// Apply the advanced args, if the user supplied any.
advancedSettings?.Invoke(_args);

_args.FeatureColumn = featureColumn;
_args.LabelColumn = labelColumn;
_args.WeightColumn = weightColumn;
Expand All @@ -1728,8 +1730,10 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
/// <summary>
/// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
/// </summary>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args)
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn)
/// <param name="env">The environment to use.</param>
/// <param name="args">Advanced arguments to the algorithm.</param>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options args)
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
args.Check(env);
_loss = args.LossFunction.CreateComponent(env);
Expand Down Expand Up @@ -1948,14 +1952,14 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig
}

[TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainHogwildSGD");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new StochasticGradientDescentClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
Expand Down
28 changes: 20 additions & 8 deletions src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace Microsoft.ML
{
using LRArguments = LogisticRegression.Arguments;
using SgdArguments = StochasticGradientDescentClassificationTrainer.Arguments;
using SgdOptions = StochasticGradientDescentClassificationTrainer.Options;

/// <summary>
/// TrainerEstimator extension methods.
Expand All @@ -31,20 +31,32 @@ public static class StandardLearnersCatalog
/// <param name="initLearningRate">The initial learning rate used by SGD.</param>
/// <param name="l2Weight">The L2 regularization constant.</param>
/// <param name="loss">The loss function to use.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
int maxIterations = SgdArguments.Defaults.MaxIterations,
double initLearningRate = SgdArguments.Defaults.InitLearningRate,
float l2Weight = SgdArguments.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null,
Action<SgdArguments> advancedSettings = null)
int maxIterations = SgdOptions.Defaults.MaxIterations,
double initLearningRate = SgdOptions.Defaults.InitLearningRate,
float l2Weight = SgdOptions.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss);
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the <see cref="StochasticGradientDescentClassificationTrainer"/> trainer.
/// </summary>
/// <param name="ctx">The binary classification context trainer object.</param>
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
public static StochasticGradientDescentClassificationTrainer StochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
SgdOptions advancedSettings)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
return new StochasticGradientDescentClassificationTrainer(env, labelColumn, featureColumn, weights, maxIterations, initLearningRate, l2Weight, loss, advancedSettings);

return new StochasticGradientDescentClassificationTrainer(env, advancedSettings);
}

/// <summary>
Expand Down
47 changes: 40 additions & 7 deletions src/Microsoft.ML.StaticPipe/SgdStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.ML.StaticPipe
{
using Arguments = StochasticGradientDescentClassificationTrainer.Arguments;
using Options = StochasticGradientDescentClassificationTrainer.Options;

/// <summary>
/// Binary Classification trainer estimators.
Expand All @@ -27,7 +27,6 @@ public static class SgdStaticExtensions
/// <param name="initLearningRate">The initial learning rate used by SGD.</param>
/// <param name="l2Weight">The L2 regularization constant.</param>
/// <param name="loss">The loss function to use.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
Expand All @@ -38,17 +37,51 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
Scalar<bool> label,
Vector<float> features,
Scalar<float> weights = null,
int maxIterations = Arguments.Defaults.MaxIterations,
double initLearningRate = Arguments.Defaults.InitLearningRate,
float l2Weight = Arguments.Defaults.L2Weight,
int maxIterations = Options.Defaults.MaxIterations,
double initLearningRate = Options.Defaults.InitLearningRate,
float l2Weight = Options.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null,
Action<Arguments> advancedSettings = null,
Action<IPredictorWithFeatureWeights<float>> onFit = null)
{
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
(env, labelName, featuresName, weightsName) =>
{
var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss, advancedSettings);
var trainer = new StochasticGradientDescentClassificationTrainer(env, labelName, featuresName, weightsName, maxIterations, initLearningRate, l2Weight, loss);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
return trainer;
}, label, features, weights);

return rec.Output;
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer"/> trainer.
/// </summary>
/// <param name="ctx">The binary classification context trainer object.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="features">The name of the feature column.</param>
/// <param name="weights">The name for the example weight column.</param>
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the linear model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to
/// be informed about what was learnt.</param>
/// <returns>The predicted output.</returns>
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) StochasticGradientDescentClassificationTrainer(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
Scalar<bool> label,
Vector<float> features,
Scalar<float> weights,
Options advancedSettings,
Action<IPredictorWithFeatureWeights<float>> onFit = null)
{
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
(env, labelName, featuresName, weightsName) =>
{
var trainer = new StochasticGradientDescentClassificationTrainer(env, advancedSettings);
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
Expand Down
2 changes: 1 addition & 1 deletion test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Traine
Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.SdcaBinaryTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMultiClass Microsoft.ML.Trainers.SdcaMultiClassTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Trainers.StochasticGradientDescentClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Trainers.SymSgd.SymSgdClassificationTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Transforms.BootstrapSample GetSample Microsoft.ML.Transforms.BootstrapSamplingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Expand Down
3 changes: 1 addition & 2 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -940,8 +940,7 @@ public void HogwildSGDBinaryClassification()
var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features,
l2Weight: 0,
onFit: (p) => { pred = p; },
advancedSettings: s => s.NumThreads = 1)));
onFit: (p) => { pred = p; })));

var pipe = reader.Append(est);

Expand Down
5 changes: 4 additions & 1 deletion test/Microsoft.ML.Tests/FeatureContributionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.RunTests;
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
using Xunit;
Expand Down Expand Up @@ -152,7 +153,9 @@ public void TestSDCABinary()
[Fact]
public void TestSGDBinary()
{
TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent(advancedSettings: args => { args.NumThreads = 1; }), GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary");
TestFeatureContribution(ML.BinaryClassification.Trainers.StochasticGradientDescent(
new StochasticGradientDescentClassificationTrainer.Options { NumThreads = 1}),
GetSparseDataset(TaskType.BinaryClassification, 100), "SGDBinary");
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void KMeansEstimator()
public void TestEstimatorHogwildSGD()
{
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
var trainer = new StochasticGradientDescentClassificationTrainer(Env, "Label", "Features");
var trainer = ML.BinaryClassification.Trainers.StochasticGradientDescent();
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);

Expand Down

0 comments on commit 2a77e5e

Please sign in to comment.