Towards dotnet#1798.
This PR addresses the estimators inside HalLearners:

Add two public extension methods per estimator, one taking simple arguments and the other taking advanced options (see the usage sketch after this list)
Delete unnecessary constructors
Pass Options objects as arguments instead of Action delegates
Rename Arguments to Options
Rename Options instances to options (instead of the args or advancedSettings names used so far)
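
To make the reworked surface concrete, here is a minimal usage sketch of the two OrdinaryLeastSquares overloads added below. The MLContext entry point, the wrapper class, the column names, and the hyperparameter values are illustrative assumptions; the method and Options member names come from the diff itself.

using Microsoft.ML;
using Microsoft.ML.Trainers.HalLearners;

public static class OlsUsageSketch
{
    public static void Build()
    {
        var mlContext = new MLContext();

        // Simple overload: column names only, everything else at its default.
        var olsSimple = mlContext.Regression.Trainers.OrdinaryLeastSquares(
            labelColumn: "Label",
            featureColumn: "Features");

        // Advanced overload: a populated Options object replaces the old Action<Arguments> delegate.
        var olsWithOptions = mlContext.Regression.Trainers.OrdinaryLeastSquares(
            new OlsLinearRegressionTrainer.Options
            {
                LabelColumn = "Label",
                FeatureColumn = "Features",
                L2Weight = 0.1f,                   // illustrative value
                PerParameterSignificance = false   // skip per-parameter statistics
            });
    }
}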
sfilipi committed Jan 18, 2019
1 parent bb92c06 commit b097a76
Showing 6 changed files with 102 additions and 82 deletions.
53 changes: 47 additions & 6 deletions src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
@@ -4,6 +4,7 @@

using System;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Trainers.SymSgd;
using Microsoft.ML.Transforms.Projections;
@@ -22,16 +23,36 @@ public static class HalLearnersCatalog
/// <param name="labelColumn">The labelColumn column.</param>
/// <param name="featureColumn">The features column.</param>
/// <param name="weights">The weights column.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
Action<OlsLinearRegressionTrainer.Arguments> advancedSettings = null)
string weights = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
var options = new OlsLinearRegressionTrainer.Options
{
LabelColumn = labelColumn,
FeatureColumn = featureColumn,
WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight)
};

return new OlsLinearRegressionTrainer(env, options);
}

/// <summary>
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
/// <param name="options">Algorithm advanced options.</param>
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
OlsLinearRegressionTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new OlsLinearRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings);
return new OlsLinearRegressionTrainer(env, options);
}

/// <summary>
@@ -44,11 +65,31 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCon
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)
Action<SymSgdClassificationTrainer.Options> advancedSettings = null)
{
Contracts.CheckValue(ctx, nameof(ctx));
var env = CatalogUtils.GetEnvironment(ctx);
var options = new SymSgdClassificationTrainer.Options
{
LabelColumn = labelColumn,
FeatureColumn = featureColumn,
};

return new SymSgdClassificationTrainer(env, options);
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
/// <param name="options">Algorithm advanced options.</param>
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
SymSgdClassificationTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));
var env = CatalogUtils.GetEnvironment(ctx);
return new SymSgdClassificationTrainer(env, labelColumn, featureColumn, advancedSettings);
return new SymSgdClassificationTrainer(env, options);
}

/// <summary>
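
The binary-classification half of the catalog follows the same pattern. Continuing inside the method from the sketch above (and additionally assuming using Microsoft.ML.Trainers.SymSgd;), the two SymbolicStochasticGradientDescent overloads would be called roughly as follows; the hyperparameter values are again purely illustrative.

// Simple overload: column names only.
var symSgdSimple = mlContext.BinaryClassification.Trainers.SymbolicStochasticGradientDescent(
    labelColumn: "Label",
    featureColumn: "Features");

// Advanced overload: an Options object passed directly, no Action delegate.
var symSgdWithOptions = mlContext.BinaryClassification.Trainers.SymbolicStochasticGradientDescent(
    new SymSgdClassificationTrainer.Options
    {
        LabelColumn = "Label",
        FeatureColumn = "Features",
        NumberOfIterations = 30,
        LearningRate = 0.1f
    });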
54 changes: 13 additions & 41 deletions src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -19,7 +19,7 @@
using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Training;

[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments),
[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Options),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
OlsLinearRegressionTrainer.UserNameValue,
OlsLinearRegressionTrainer.LoadNameValue,
@@ -34,9 +34,10 @@
namespace Microsoft.ML.Trainers.HalLearners
{
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
[BestFriend]
public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsLinearRegressionModelParameters>, OlsLinearRegressionModelParameters>
{
public sealed class Arguments : LearnerInputBaseWithWeight
public sealed class Options : LearnerInputBaseWithWeight
{
// Adding L2 regularization turns this into a form of ridge regression,
// rather than, strictly speaking, ordinary least squares. But it is an
@@ -46,13 +47,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight
[TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[] { 1e-6f, 0.1f, 1f })]
public float L2Weight = 1e-6f;

/// <summary>
/// Whether to calculate per parameter significance statistics.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to calculate per parameter significance statistics", ShortName = "sig")]
public bool PerParameterSignificance = true;
}

public const string LoadNameValue = "OLSLinearRegression";
public const string UserNameValue = "Ordinary Least Squares (Regression)";
public const string ShortName = "ols";
internal const string LoadNameValue = "OLSLinearRegression";
internal const string UserNameValue = "Ordinary Least Squares (Regression)";
internal const string ShortName = "ols";
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
+ "that minimizes the square loss function.";

@@ -68,24 +72,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
/// <summary>
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="labelColumn">The name of the labelColumn column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="weights">The name for the optional example weight column.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public OlsLinearRegressionTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
Action<Arguments> advancedSettings = null)
: this(env, ArgsInit(featureColumn, labelColumn, weights, advancedSettings))
{
}

/// <summary>
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
/// </summary>
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
@@ -95,21 +82,6 @@ internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
_perParameterSignificance = args.PerParameterSignificance;
}

private static Arguments ArgsInit(string featureColumn,
string labelColumn,
string weightColumn,
Action<Arguments> advancedSettings)
{
var args = new Arguments();

// Apply the advanced args, if the user supplied any.
advancedSettings?.Invoke(args);
args.FeatureColumn = featureColumn;
args.LabelColumn = labelColumn;
args.WeightColumn = weightColumn != null ? Optional<string>.Explicit(weightColumn) : Optional<string>.Implicit(DefaultColumnNames.Weight);
return args;
}

protected override RegressionPredictionTransformer<OlsLinearRegressionModelParameters> MakeTransformer(OlsLinearRegressionModelParameters model, Schema trainSchema)
=> new RegressionPredictionTransformer<OlsLinearRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

@@ -518,14 +490,14 @@ public static void Pptri(Layout layout, UpLo uplo, int n, Double[] ap)
UserName = UserNameValue,
ShortName = ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""OLS""]/*' />" })]
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input)
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainOLS");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.RegressionOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.RegressionOutput>(host, input,
() => new OlsLinearRegressionTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
@@ -579,7 +551,7 @@ private static VersionInfo GetVersionInfo()
/// are all null. A model may not have per parameter statistics because either
/// there were not more examples than parameters in the model, or because they
/// were explicitly suppressed in training by setting
/// <see cref="OlsLinearRegressionTrainer.Arguments.PerParameterSignificance"/>
/// <see cref="OlsLinearRegressionTrainer.Options.PerParameterSignificance"/>
/// to false.
/// </summary>
public bool HasStatistics => _standardErrors != null;
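
As a sanity check on the statistics behavior described above (HasStatistics versus Options.PerParameterSignificance), a hedged sketch; trainData is an assumed IDataView with Label and Features columns, and the Model property access mirrors RegressionPredictionTransformer<T>.

// Suppress per-parameter statistics at training time.
var trainer = mlContext.Regression.Trainers.OrdinaryLeastSquares(
    new OlsLinearRegressionTrainer.Options { PerParameterSignificance = false });

var transformer = trainer.Fit(trainData);
OlsLinearRegressionModelParameters model = transformer.Model;
System.Console.WriteLine(model.HasStatistics);   // expected: False, since the statistics were suppressed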
2 changes: 2 additions & 0 deletions src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs
@@ -5,6 +5,8 @@
using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
67 changes: 36 additions & 31 deletions src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
@@ -20,7 +20,7 @@
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments),
[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SymSgdClassificationTrainer.UserNameValue,
SymSgdClassificationTrainer.LoadNameValue,
@@ -33,48 +33,78 @@ namespace Microsoft.ML.Trainers.SymSgd
using TPredictor = IPredictorWithFeatureWeights<float>;

/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' />
[BestFriend]
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor>
{
internal const string LoadNameValue = "SymbolicSGD";
internal const string UserNameValue = "Symbolic SGD (binary)";
internal const string ShortName = "SymSGD";

public sealed class Arguments : LearnerInputBaseWithLabel
public sealed class Options : LearnerInputBaseWithLabel
{
/// <summary>
/// Degree of lock-free parallelism. Determinism not guaranteed.
/// Multi-threading is not supported currently.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " +
"Multi-threading is not supported currently.", ShortName = "nt")]
public int? NumberOfThreads;

/// <summary>
/// Number of passes over the data.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of passes over the data.", ShortName = "iter", SortOrder = 50)]
[TGUI(SuggestedSweeps = "1,5,10,20,30,40,50")]
[TlcModule.SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 })]
public int NumberOfIterations = 50;

/// <summary>
/// Tolerance for difference in average loss in consecutive passes.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance for difference in average loss in consecutive passes.", ShortName = "tol")]
public float Tolerance = 1e-4f;

/// <summary>
/// Learning rate.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", NullName = "<Auto>", SortOrder = 51)]
[TGUI(SuggestedSweeps = "<Auto>,1e1,1e0,1e-1,1e-2,1e-3")]
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { "<Auto>", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f })]
public float? LearningRate;

/// <summary>
/// L2 regularization.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization", ShortName = "l2", SortOrder = 52)]
[TGUI(SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7")]
[TlcModule.SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f })]
public float L2Regularization;

/// <summary>
/// The number of iterations each thread learns a local model until combining it with the
/// global model. Low value means more updated global model and high value means less cache traffic.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of iterations each thread learns a local model until combining it with the " +
"global model. Low value means more updated global model and high value means less cache traffic.", ShortName = "freq", NullName = "<Auto>")]
[TGUI(SuggestedSweeps = "<Auto>,5,20")]
[TlcModule.SweepableDiscreteParam("UpdateFrequency", new object[] { "<Auto>", 5, 20 })]
public int? UpdateFrequency;

/// <summary>
/// The acceleration memory budget in MB.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "The acceleration memory budget in MB", ShortName = "accelMemBudget")]
public long MemorySize = 1024;

/// <summary>
/// Setting this to <see langword="true" /> causes the data to be shuffled.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data?", ShortName = "shuf")]
public bool Shuffle = true;

/// <summary>
/// Apply weight to the positive class, for imbalanced data.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
public float PositiveInstanceWeight = 1;

@@ -88,7 +118,7 @@ public void Check(IExceptionContext ectx)
}

public override TrainerInfo Info { get; }
private readonly Arguments _args;
private readonly Options _args;

/// <summary>
/// This method ensures that the data meets the requirements of this trainer and its
@@ -152,32 +182,7 @@ private protected override TPredictor TrainModelCore(TrainContext context)
/// <summary>
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public SymSgdClassificationTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
Action<Arguments> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
TrainerUtils.MakeBoolScalarLabel(labelColumn))
{
_args = new Arguments();

// Apply the advanced args, if the user supplied any.
_args.Check(Host);
advancedSettings?.Invoke(_args);
_args.FeatureColumn = featureColumn;
_args.LabelColumn = labelColumn;

Info = new TrainerInfo(supportIncrementalTrain: true);
}

/// <summary>
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
/// </summary>
internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
internal SymSgdClassificationTrainer(IHostEnvironment env, Options args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
@@ -218,14 +223,14 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
UserName = SymSgdClassificationTrainer.UserNameValue,
ShortName = SymSgdClassificationTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name=""SymSGD""]/*' />" })]
public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input)
public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSymSGD");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new SymSgdClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
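
To round out the picture, the remaining fields documented in the Options class above can be set the same way; the call continues the earlier binary-classification sketch, and the values are illustrative only.

var tunedSymSgd = mlContext.BinaryClassification.Trainers.SymbolicStochasticGradientDescent(
    new SymSgdClassificationTrainer.Options
    {
        NumberOfThreads = 1,            // multi-threading is not supported currently, per the doc above
        Tolerance = 1e-4f,
        L2Regularization = 1e-6f,
        UpdateFrequency = 5,
        MemorySize = 1024,              // acceleration memory budget in MB
        Shuffle = true,
        PositiveInstanceWeight = 2f     // up-weight the positive class for imbalanced data
    });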