Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify API for advanced settings. (SDCA) #2093

Merged
merged 13 commits into from
Jan 18, 2019
17 changes: 8 additions & 9 deletions docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Samples.Dynamic
{
Expand Down Expand Up @@ -59,15 +60,13 @@ public static void SDCA_BinaryClassification()
// If we wanted to specify more advanced parameters for the algorithm,
// we could do so by tweaking the 'advancedSetting'.
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent
(labelColumn: "Sentiment",
featureColumn: "Features",
advancedSettings: s=>
{
s.ConvergenceTolerance = 0.01f; // The learning rate for adjusting bias from being regularized
s.NumThreads = 2; // Degree of lock-free parallelism
})
);
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {
LabelColumn = "Sentiment",
FeatureColumn = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
NumThreads = 2, // Degree of lock-free parallelism
}));

// Run Cross-Validation on this second pipeline.
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
Expand Down
93 changes: 42 additions & 51 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Arguments),
[assembly: LoadableClass(typeof(SdcaBinaryTrainer), typeof(SdcaBinaryTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaBinaryTrainer.UserNameValue,
SdcaBinaryTrainer.LoadNameValue,
"LinearClassifier",
"lc",
"sasdca")]

[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Arguments),
[assembly: LoadableClass(typeof(StochasticGradientDescentClassificationTrainer), typeof(StochasticGradientDescentClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
StochasticGradientDescentClassificationTrainer.UserNameValue,
StochasticGradientDescentClassificationTrainer.LoadNameValue,
Expand Down Expand Up @@ -69,6 +69,13 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn,
{
}

private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
SchemaShape.Column weightColumn)
Copy link
Member

@sfilipi sfilipi Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this to account for the corresponding changes in line #1728-1730

I saw some tests fail saying "Weights" column not found. The fix was to specify options.WeightColumn.IsExplicit in the call below.

TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit)


In reply to: 248006972 [](ancestors = 248006972)

Copy link
Member Author

@abgoswam abgoswam Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got rid of this


In reply to: 248058147 [](ancestors = 248058147,248006972)

: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn),
labelColumn, weightColumn)
{
}

private protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
Expand Down Expand Up @@ -246,21 +253,19 @@ protected enum MetricKind

private const string RegisterName = nameof(SdcaTrainerBase<TArgs, TTransformer, TModel>);

private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, Action<TArgs> advancedSettings = null)
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn)
{
var args = new TArgs();

// Apply the advanced args, if the user supplied any.
advancedSettings?.Invoke(args);
args.FeatureColumn = featureColumn;
args.LabelColumn = labelColumn.Name;
return args;
}

internal SdcaTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
Copy link
Member

@sfilipi sfilipi Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal SdcaTrainerBase [](start = 6, length = 26)

delete #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2100


In reply to: 248770895 [](ancestors = 248770895)

SchemaShape.Column weight = default, Action<TArgs> advancedSettings = null, float? l2Const = null,
SchemaShape.Column weight = default, float? l2Const = null,
float? l1Threshold = null, int? maxIterations = null)
: this(env, ArgsInit(featureColumn, labelColumn, advancedSettings), labelColumn, weight, l2Const, l1Threshold, maxIterations)
: this(env, ArgsInit(featureColumn, labelColumn), labelColumn, weight, l2Const, l1Threshold, maxIterations)
{
}

Expand Down Expand Up @@ -1391,12 +1396,13 @@ public void Add(Double summand)
}
}

public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Arguments, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
public sealed class SdcaBinaryTrainer : SdcaTrainerBase<SdcaBinaryTrainer.Options, BinaryPredictionTransformer<TScalarPredictor>, TScalarPredictor>
{
public const string LoadNameValue = "SDCA";
Copy link
Member

@sfilipi sfilipi Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

all those strings should be internal. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there some separate issue for it that I can link to this comment ?

I would like to keep the PRs for public API focused on API rather than fixing issues across the codebase :)


In reply to: 248771027 [](ancestors = 248771027)


internal const string UserNameValue = "Fast Linear (SA-SDCA)";

public sealed class Arguments : ArgumentsBase
public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
Expand Down Expand Up @@ -1441,21 +1447,16 @@ internal override void Check(IHostEnvironment env)
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public SdcaBinaryTrainer(IHostEnvironment env,
internal SdcaBinaryTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null,
Action<Arguments> advancedSettings = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings,
l2Const, l1Threshold, maxIterations)
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
Expand Down Expand Up @@ -1495,11 +1496,10 @@ public SdcaBinaryTrainer(IHostEnvironment env,
_outputColumns = outCols.ToArray();
}

internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
internal SdcaBinaryTrainer(IHostEnvironment env, Options options)
Copy link
Member

@singlis singlis Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here as I made with the SDCARegression -- does SdcaBinary not have common args or is it always options? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will eventually have just 1 constructor (the one with Options)

#2100


In reply to: 248866446 [](ancestors = 248866446)

: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
{
_loss = args.LossFunction.CreateComponent(env);
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
Info = new TrainerInfo(calibration: !(_loss is LogLoss));
_positiveInstanceWeight = Args.PositiveInstanceWeight;
Expand Down Expand Up @@ -1533,12 +1533,6 @@ internal SdcaBinaryTrainer(IHostEnvironment env, Arguments args,
};

_outputColumns = outCols.ToArray();

}

public SdcaBinaryTrainer(IHostEnvironment env, Arguments args)
Copy link
Member

@sfilipi sfilipi Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public SdcaBinaryTrainer(IHostEnvironment env, Arguments args) [](start = 6, length = 64)

i think this is needed for the SignatureTrainer. Doesn't have to be public. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor above has this signature, so it's not really going away.


In reply to: 248771919 [](ancestors = 248771919)

: this(env, args, args.FeatureColumn, args.LabelColumn)
{
}

protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
Expand Down Expand Up @@ -1594,7 +1588,7 @@ public sealed class StochasticGradientDescentClassificationTrainer :
internal const string UserNameValue = "Hogwild SGD (binary)";
internal const string ShortName = "HogwildSGD";

public sealed class Arguments : LearnerInputBaseWithWeight
public sealed class Options : LearnerInputBaseWithWeight
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportClassificationLossFactory LossFunction = new LogLossFactory();
Copy link
Member

@sfilipi sfilipi Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISupportClassificationLossFactory [](start = 19, length = 33)

XML docs #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #2154 to fix this issue in bulk for several learners


In reply to: 248022036 [](ancestors = 248022036)

Expand Down Expand Up @@ -1669,7 +1663,7 @@ internal static class Defaults
}

private readonly IClassificationLoss _loss;
private readonly Arguments _args;
private readonly Options _args;

protected override bool ShuffleData => _args.Shuffle;

Expand All @@ -1688,29 +1682,24 @@ internal static class Defaults
/// <param name="initLearningRate">The initial learning rate used by SGD.</param>
/// <param name="l2Weight">The L2 regularizer constant.</param>
/// <param name="loss">The loss function to use.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
Copy link
Member

@sfilipi sfilipi Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal StochasticGradientDescentClassificationTraine [](start = 8, length = 54)

I'd remove all ctors but the one with the (IHostEnvironment env, Arguments) signature, unless it is needed for inheritance. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly — that's the point of issue #2100, but that would be a separate PR.

these PRs focus on public API which is P0 (as per consensus in email thread)


In reply to: 248772619 [](ancestors = 248772619)

string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weightColumn = null,
Copy link
Member

@singlis singlis Jan 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be set to Defaults.WeightColumn? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to take a closer look at how 'weights' are being used, or whether they are being used at all.

please see #2175.


In reply to: 248867308 [](ancestors = 248867308)

int maxIterations = Arguments.Defaults.MaxIterations,
double initLearningRate = Arguments.Defaults.InitLearningRate,
float l2Weight = Arguments.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null,
Action<Arguments> advancedSettings = null)
int maxIterations = Options.Defaults.MaxIterations,
double initLearningRate = Options.Defaults.InitLearningRate,
float l2Weight = Options.Defaults.L2Weight,
ISupportClassificationLossFactory loss = null)
: base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), weightColumn)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

_args = new Arguments();
_args = new Options();
Copy link
Contributor

@artidoro artidoro Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_args [](start = 12, length = 5)

Could you make this _options? So that it is consistent with the general renaming? #Resolved

_args.MaxIterations = maxIterations;
_args.InitLearningRate = initLearningRate;
_args.L2Weight = l2Weight;

// Apply the advanced args, if the user supplied any.
advancedSettings?.Invoke(_args);

_args.FeatureColumn = featureColumn;
_args.LabelColumn = labelColumn;
_args.WeightColumn = weightColumn;
Expand All @@ -1727,14 +1716,16 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env,
/// <summary>
/// Initializes a new instance of <see cref="StochasticGradientDescentClassificationTrainer"/>
/// </summary>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Arguments args)
: base(env, args.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn), args.WeightColumn)
/// <param name="env">The environment to use.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Options options)
: base(env, options.FeatureColumn, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit))
Copy link
Contributor

@artidoro artidoro Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn, options.WeightColumn.IsExplicit) [](start = 102, length = 92)

I think you can replace this with:

options.WeightColumn.IsExplicit ? options.WeightColumn : null

That way you won't need the other constructor for the base. This is a strange problem from wanting to keep the optional for the weightColumn which is used in Maml. #Resolved

Copy link
Member Author

@abgoswam abgoswam Jan 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfilipi who is taking a look at how best to handle the weights column

In this PR, I would like to keep it as is. This has ramifications for several tests.


In reply to: 248102195 [](ancestors = 248102195)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a pretty simple fix, which reduces the number of constructors in the base class, I would check if it works :)


In reply to: 248104112 [](ancestors = 248104112,248102195)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified. Thanks for debugging!


In reply to: 248104779 [](ancestors = 248104779,248104112,248102195)

{
args.Check(env);
_loss = args.LossFunction.CreateComponent(env);
options.Check(env);
_loss = options.LossFunction.CreateComponent(env);
Info = new TrainerInfo(calibration: !(_loss is LogLoss), supportIncrementalTrain: true);
NeedShuffle = args.Shuffle;
_args = args;
NeedShuffle = options.Shuffle;
_args = options;
}

protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
Expand Down Expand Up @@ -1947,14 +1938,14 @@ private protected override void CheckLabel(RoleMappedData examples, out int weig
}

[TlcModule.EntryPoint(Name = "Trainers.StochasticGradientDescentBinaryClassifier", Desc = "Train an Hogwild SGD binary model.", UserName = UserNameValue, ShortName = ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainHogwildSGD");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new StochasticGradientDescentClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
Expand All @@ -1974,14 +1965,14 @@ public static partial class Sdca
ShortName = SdcaBinaryTrainer.LoadNameValue,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentBinaryClassifier""]/*'/>" })]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, SdcaBinaryTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<SdcaBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new SdcaBinaryTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
Expand Down
Loading