diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
index 62420c9b0f0..9afdc1b97f2 100644
--- a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
+++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs
@@ -4,6 +4,7 @@
using System;
using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Trainers.SymSgd;
using Microsoft.ML.Transforms.Projections;
@@ -22,16 +23,36 @@ public static class HalLearnersCatalog
/// <param name="labelColumn">The labelColumn column.</param>
/// <param name="featureColumn">The features column.</param>
/// <param name="weights">The weights column.</param>
- /// <param name="advancedSettings">Algorithm advanced settings.</param>
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
- string weights = null,
- Action<OlsLinearRegressionTrainer.Arguments> advancedSettings = null)
+ string weights = null)
+ {
+ Contracts.CheckValue(ctx, nameof(ctx));
+ var env = CatalogUtils.GetEnvironment(ctx);
+ var options = new OlsLinearRegressionTrainer.Options
+ {
+ LabelColumn = labelColumn,
+ FeatureColumn = featureColumn,
+ // A caller-supplied weight column is marked explicit; when none is given the
+ // default weight column name is passed as an implicit value instead.
+ WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight)
+ };
+
+ return new OlsLinearRegressionTrainer(env, options);
+ }
+
+ /// <summary>
+ /// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
+ /// </summary>
+ /// <param name="ctx">The <see cref="RegressionContext.RegressionTrainers"/>.</param>
+ /// <param name="options">Algorithm advanced options.</param>
+ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
+ OlsLinearRegressionTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
+ Contracts.CheckValue(options, nameof(options));
+
var env = CatalogUtils.GetEnvironment(ctx);
- return new OlsLinearRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings);
+ return new OlsLinearRegressionTrainer(env, options);
}
///
@@ -44,11 +65,31 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCon
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
- Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)
+ Action<SymSgdClassificationTrainer.Options> advancedSettings = null)
+ {
+ Contracts.CheckValue(ctx, nameof(ctx));
+ var env = CatalogUtils.GetEnvironment(ctx);
+ var options = new SymSgdClassificationTrainer.Options();
+
+ // Apply the caller's advanced settings before assigning the column names, so the
+ // explicit labelColumn/featureColumn arguments always win. This is the same order
+ // the removed public SymSgdClassificationTrainer constructor used; without this
+ // call the advancedSettings delegate would be accepted but silently ignored.
+ advancedSettings?.Invoke(options);
+ options.LabelColumn = labelColumn;
+ options.FeatureColumn = featureColumn;
+
+ return new SymSgdClassificationTrainer(env, options);
+ }
+
+ /// <summary>
+ /// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
+ /// </summary>
+ /// <param name="ctx">The <see cref="BinaryClassificationContext.BinaryClassificationTrainers"/>.</param>
+ /// <param name="options">Algorithm advanced options.</param>
+ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+ SymSgdClassificationTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
+ Contracts.CheckValue(options, nameof(options));
var env = CatalogUtils.GetEnvironment(ctx);
- return new SymSgdClassificationTrainer(env, labelColumn, featureColumn, advancedSettings);
+ return new SymSgdClassificationTrainer(env, options);
}
///
diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index a4f42bbb39e..0cc61f4b22b 100644
--- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -19,7 +19,7 @@
using Microsoft.ML.Trainers.HalLearners;
using Microsoft.ML.Training;
-[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments),
+[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Options),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
OlsLinearRegressionTrainer.UserNameValue,
OlsLinearRegressionTrainer.LoadNameValue,
@@ -34,9 +34,10 @@
namespace Microsoft.ML.Trainers.HalLearners
{
///
+ [BestFriend]
public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase, OlsLinearRegressionModelParameters>
{
- public sealed class Arguments : LearnerInputBaseWithWeight
+ public sealed class Options : LearnerInputBaseWithWeight
{
// Adding L2 regularization turns this into a form of ridge regression,
// rather than, strictly speaking, ordinary least squares. But it is an
@@ -46,13 +47,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight
[TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[] { 1e-6f, 0.1f, 1f })]
public float L2Weight = 1e-6f;
+ /// <summary>
+ /// Whether to calculate per parameter significance statistics.
+ /// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to calculate per parameter significance statistics", ShortName = "sig")]
public bool PerParameterSignificance = true;
}
- public const string LoadNameValue = "OLSLinearRegression";
- public const string UserNameValue = "Ordinary Least Squares (Regression)";
- public const string ShortName = "ols";
+ internal const string LoadNameValue = "OLSLinearRegression";
+ internal const string UserNameValue = "Ordinary Least Squares (Regression)";
+ internal const string ShortName = "ols";
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
+ "that minimizes the square loss function.";
@@ -68,24 +72,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
/// <summary>
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
/// </summary>
- /// The environment to use.
- /// The name of the labelColumn column.
- /// The name of the feature column.
- /// The name for the optional example weight column.
- /// A delegate to apply all the advanced arguments to the algorithm.
- public OlsLinearRegressionTrainer(IHostEnvironment env,
- string labelColumn = DefaultColumnNames.Label,
- string featureColumn = DefaultColumnNames.Features,
- string weights = null,
- Action advancedSettings = null)
- : this(env, ArgsInit(featureColumn, labelColumn, weights, advancedSettings))
- {
- }
-
- ///
- /// Initializes a new instance of
- ///
- internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
+ internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
@@ -95,21 +82,6 @@ internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
_perParameterSignificance = args.PerParameterSignificance;
}
- private static Arguments ArgsInit(string featureColumn,
- string labelColumn,
- string weightColumn,
- Action advancedSettings)
- {
- var args = new Arguments();
-
- // Apply the advanced args, if the user supplied any.
- advancedSettings?.Invoke(args);
- args.FeatureColumn = featureColumn;
- args.LabelColumn = labelColumn;
- args.WeightColumn = weightColumn;
- return args;
- }
-
protected override RegressionPredictionTransformer MakeTransformer(OlsLinearRegressionModelParameters model, Schema trainSchema)
=> new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name);
@@ -518,14 +490,14 @@ public static void Pptri(Layout layout, UpLo uplo, int n, Double[] ap)
UserName = UserNameValue,
ShortName = ShortName,
XmlInclude = new[] { @"" })]
- public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input)
+ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainOLS");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input,
+ return LearnerEntryPointsUtils.Train(host, input,
() => new OlsLinearRegressionTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
@@ -579,7 +551,7 @@ private static VersionInfo GetVersionInfo()
/// are all null. A model may not have per parameter statistics because either
/// there were not more examples than parameters in the model, or because they
/// were explicitly suppressed in training by setting
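- /// <see cref="OlsLinearRegressionTrainer.Arguments.PerParameterSignificance"/>
+ /// <see cref="OlsLinearRegressionTrainer.Options.PerParameterSignificance"/>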
/// to false.
///
public bool HasStatistics => _standardErrors != null;
diff --git a/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs
index 378fcf459ae..9a3e26f765c 100644
--- a/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs
@@ -5,6 +5,8 @@
using System.Runtime.CompilerServices;
using Microsoft.ML;
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
index c4a57c65478..c1141bed81b 100644
--- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
+++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
@@ -20,7 +20,7 @@
using Microsoft.ML.Training;
using Microsoft.ML.Transforms;
-[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments),
+[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SymSgdClassificationTrainer.UserNameValue,
SymSgdClassificationTrainer.LoadNameValue,
@@ -33,48 +33,78 @@ namespace Microsoft.ML.Trainers.SymSgd
using TPredictor = IPredictorWithFeatureWeights;
///
+ [BestFriend]
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase, TPredictor>
{
internal const string LoadNameValue = "SymbolicSGD";
internal const string UserNameValue = "Symbolic SGD (binary)";
internal const string ShortName = "SymSGD";
- public sealed class Arguments : LearnerInputBaseWithLabel
+ public sealed class Options : LearnerInputBaseWithLabel
{
+ /// <summary>
+ /// Degree of lock-free parallelism. Determinism not guaranteed.
+ /// Multi-threading is not supported currently.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " +
"Multi-threading is not supported currently.", ShortName = "nt")]
public int? NumberOfThreads;
+ /// <summary>
+ /// Number of passes over the data.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of passes over the data.", ShortName = "iter", SortOrder = 50)]
[TGUI(SuggestedSweeps = "1,5,10,20,30,40,50")]
[TlcModule.SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 })]
public int NumberOfIterations = 50;
+ /// <summary>
+ /// Tolerance for difference in average loss in consecutive passes.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance for difference in average loss in consecutive passes.", ShortName = "tol")]
public float Tolerance = 1e-4f;
+ /// <summary>
+ /// Learning rate.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", NullName = "", SortOrder = 51)]
[TGUI(SuggestedSweeps = ",1e1,1e0,1e-1,1e-2,1e-3")]
[TlcModule.SweepableDiscreteParam("LearningRate", new object[] { "", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f })]
public float? LearningRate;
+ /// <summary>
+ /// L2 regularization.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization", ShortName = "l2", SortOrder = 52)]
[TGUI(SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7")]
[TlcModule.SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f })]
public float L2Regularization;
+ /// <summary>
+ /// The number of iterations each thread learns a local model until combining it with the
+ /// global model. Low value means more updated global model and high value means less cache traffic.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of iterations each thread learns a local model until combining it with the " +
"global model. Low value means more updated global model and high value means less cache traffic.", ShortName = "freq", NullName = "")]
[TGUI(SuggestedSweeps = ",5,20")]
[TlcModule.SweepableDiscreteParam("UpdateFrequency", new object[] { "", 5, 20 })]
public int? UpdateFrequency;
+ /// <summary>
+ /// The acceleration memory budget in MB.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "The acceleration memory budget in MB", ShortName = "accelMemBudget")]
public long MemorySize = 1024;
+ /// <summary>
+ /// Setting this to <see langword="true"/> causes the data to be shuffled.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data?", ShortName = "shuf")]
public bool Shuffle = true;
+ /// <summary>
+ /// Apply weight to the positive class, for imbalanced data.
+ /// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")]
public float PositiveInstanceWeight = 1;
@@ -88,7 +118,7 @@ public void Check(IExceptionContext ectx)
}
public override TrainerInfo Info { get; }
- private readonly Arguments _args;
+ private readonly Options _args;
///
/// This method ensures that the data meets the requirements of this trainer and its
@@ -152,32 +182,7 @@ private protected override TPredictor TrainModelCore(TrainContext context)
/// <summary>
/// Initializes a new instance of <see cref="SymSgdClassificationTrainer"/>
/// </summary>
- /// The private instance of .
- /// The name of the label column.
- /// The name of the feature column.
- /// A delegate to apply all the advanced arguments to the algorithm.
- public SymSgdClassificationTrainer(IHostEnvironment env,
- string labelColumn = DefaultColumnNames.Label,
- string featureColumn = DefaultColumnNames.Features,
- Action advancedSettings = null)
- : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn),
- TrainerUtils.MakeBoolScalarLabel(labelColumn))
- {
- _args = new Arguments();
-
- // Apply the advanced args, if the user supplied any.
- _args.Check(Host);
- advancedSettings?.Invoke(_args);
- _args.FeatureColumn = featureColumn;
- _args.LabelColumn = labelColumn;
-
- Info = new TrainerInfo(supportIncrementalTrain: true);
- }
-
- ///
- /// Initializes a new instance of
- ///
- internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args)
+ internal SymSgdClassificationTrainer(IHostEnvironment env, Options args)
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
@@ -218,14 +223,14 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
UserName = SymSgdClassificationTrainer.UserNameValue,
ShortName = SymSgdClassificationTrainer.ShortName,
XmlInclude = new[] { @"" })]
- public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input)
+ public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSymSGD");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input,
+ return LearnerEntryPointsUtils.Train(host, input,
() => new SymSgdClassificationTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs
index 048ba54da7d..3494d4b48ef 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs
@@ -13,7 +13,7 @@ public partial class TrainerEstimators
public void TestEstimatorOlsLinearRegression()
{
var dataView = GetRegressionPipeline();
- var trainer = new OlsLinearRegressionTrainer(Env, "Label", "Features");
+ var trainer = new OlsLinearRegressionTrainer(Env, new OlsLinearRegressionTrainer.Options());
TestEstimatorCore(trainer, dataView);
var model = trainer.Fit(dataView);
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs
index 30dc030aed4..522363d70c5 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs
@@ -16,7 +16,7 @@ public partial class TrainerEstimators
public void TestEstimatorSymSgdClassificationTrainer()
{
(var pipe, var dataView) = GetBinaryClassificationPipeline();
- var trainer = new SymSgdClassificationTrainer(Env, "Label", "Features");
+ var trainer = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options());
var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);
@@ -35,10 +35,10 @@ public void TestEstimatorSymSgdInitPredictor()
var initPredictor = new SdcaBinaryTrainer(Env, "Label", "Features").Fit(transformedData);
var data = initPredictor.Transform(transformedData);
- var withInitPredictor = new SymSgdClassificationTrainer(Env, "Label", "Features").Train(transformedData, initialPredictor: initPredictor.Model);
+ var withInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()).Train(transformedData, initialPredictor: initPredictor.Model);
var outInitData = withInitPredictor.Transform(transformedData);
- var notInitPredictor = new SymSgdClassificationTrainer(Env, "Label", "Features").Train(transformedData);
+ var notInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()).Train(transformedData);
var outNoInitData = notInitPredictor.Transform(transformedData);
int numExamples = 10;