Skip to content

Commit

Permalink
SdcaBinaryTrainer, SdcaMultiClassTrainer, SdcaRegressionTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
abgoswam committed Jan 14, 2019
1 parent 2a77e5e commit f0b9565
Show file tree
Hide file tree
Showing 31 changed files with 422 additions and 181 deletions.
17 changes: 8 additions & 9 deletions docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Samples.Dynamic
{
Expand Down Expand Up @@ -59,15 +60,13 @@ public static void SDCA_BinaryClassification()
// If we wanted to specify more advanced parameters for the algorithm,
// we could do so by tweaking the 'advancedSetting'.
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent
(labelColumn: "Sentiment",
featureColumn: "Features",
advancedSettings: s=>
{
s.ConvergenceTolerance = 0.01f; // The learning rate for adjusting bias from being regularized
s.NumThreads = 2; // Degree of lock-free parallelism
})
);
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {
LabelColumn = "Sentiment",
FeatureColumn = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
NumThreads = 2, // Degree of lock-free parallelism
}));

// Run Cross-Validation on this second pipeline.
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
Expand Down
55 changes: 24 additions & 31 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments),
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaBinaryTrainer.UserNameValue,
SdcaBinaryTrainer.LoadNameValue,
Expand Down Expand Up @@ -253,21 +253,19 @@ protected enum MetricKind

private const string RegisterName = nameof(SdcaTrainerBase<TArgs, TTransformer, TModel>);

private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action<TArgs> advancedSettings = null)
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn)
{
var args = new TArgs();

// Apply the advanced args, if the user supplied any.
advancedSettings?.Invoke(args);
args.FeatureColumn = featureColumn;
args.LabelColumn = labelColumn.Name;
return args;
}

internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
SchemaShape.Column weight = default, Action<TArgs> advancedSettings = null, float? l2Const = null,
SchemaShape.Column weight = default, float? l2Const = null,
float? l1Threshold = null, int? maxIterations = null)
: this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations)
: this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations)
{
}

Expand Down Expand Up @@ -1398,13 +1396,13 @@ public void Add(Double summand)
}
}

public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Arguments, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
{
public const string LoadNameValue = "SDCA";

internal const string UserNameValue = "Fast Linear (SA-SDCA)";

public sealed class Arguments : ArgumentsBase
public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
Expand Down Expand Up @@ -1449,21 +1447,16 @@ internal override void Check(IHostEnvironment env)
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public SdcaBinaryTrainer(IHostEnvironment env,
internal SdcaBinaryTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null,
Action<Arguments> advancedSettings = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings,
l2Const, l1Threshold, maxIterations)
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
Expand Down Expand Up @@ -1503,11 +1496,11 @@ public SdcaBinaryTrainer(IHostEnvironment env,
_outputColumns = outCols.ToArray();
}

internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
internal SdcaBinaryTrainer(IHostEnvironment env, Options options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
: base(env, options, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
_loss = args.LossFunction.CreateComponent(env);
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
_positiveInstanceWeight = Args.PositiveInstanceWeight;
Expand Down Expand Up @@ -1544,8 +1537,8 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,

}

public SdcaBinaryTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
: this(env, options, options.FeatureColumn, options.LabelColumn)
{
}

Expand Down Expand Up @@ -1731,15 +1724,15 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
/// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="args">Advanced arguments to the algorithm.</param>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options args)
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
/// <param name="options">Advanced arguments to the algorithm.</param>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options)
: base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit))
{
args.Check(env);
_loss = args.LossFunction.CreateComponent(env);
options.Check(env);
_loss = options.LossFunction.CreateComponent(env);
Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true);
NeedShuffle = args.Shuffle;
_args = args;
NeedShuffle = options.Shuffle;
_args = options;
}

protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand Down Expand Up @@ -1979,14 +1972,14 @@ public static partial class Sdca
ShortName = SdcaBinaryTrainer.LoadNameValue,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new SdcaBinaryTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
Expand Down
31 changes: 13 additions & 18 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
using Microsoft.ML.Training;
using Float = System.Single;

[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Arguments),
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Options),
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaMultiClassTrainer.UserNameValue,
SdcaMultiClassTrainer.LoadNameValue,
Expand All @@ -29,14 +29,14 @@ namespace Microsoft.ML.Trainers
{
// SDCA linear multiclass trainer.
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' />
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Arguments, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
{
public const string LoadNameValue = "SDCAMC";
public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
public const string ShortName = "sasdcamc";
internal const string Summary = "The SDCA linear multi-class classification trainer.";

public sealed class Arguments : ArgumentsBase
public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
Expand All @@ -57,41 +57,36 @@ public sealed class Arguments : ArgumentsBase
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public SdcaMultiClassTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null,
Action<Arguments> advancedSettings = null)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings,
l2Const, l1Threshold, maxIterations)
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
_loss = loss ?? Args.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
Host.CheckValue(labelColumn, nameof(labelColumn));
Host.CheckValue(featureColumn, nameof(featureColumn));

_loss = args.LossFunction.CreateComponent(env);
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options)
: this(env, options, options.FeatureColumn, options.LabelColumn)
{
}

Expand Down Expand Up @@ -455,14 +450,14 @@ public static partial class Sdca
ShortName = SdcaMultiClassTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentClassifier""]/*' />" })]
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments input)
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
() => new SdcaMultiClassTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
Expand Down
Loading

0 comments on commit f0b9565

Please sign in to comment.