diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs index 62420c9b0f0..9afdc1b97f2 100644 --- a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs +++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs @@ -4,6 +4,7 @@ using System; using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Trainers.SymSgd; using Microsoft.ML.Transforms.Projections; @@ -22,16 +23,36 @@ public static class HalLearnersCatalog /// The labelColumn column. /// The features column. /// The weights column. - /// Algorithm advanced settings. public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, - string weights = null, - Action advancedSettings = null) + string weights = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + var options = new OlsLinearRegressionTrainer.Options + { + LabelColumn = labelColumn, + FeatureColumn = featureColumn, + WeightColumn = weights != null ? Optional.Explicit(weights) : Optional.Implicit(DefaultColumnNames.Weight) + }; + + return new OlsLinearRegressionTrainer(env, options); + } + + /// + /// Predict a target using a linear regression model trained with the OlsLinearRegressionTrainer. + /// + /// The RegressionContext.RegressionTrainers instance. + /// Algorithm advanced options. 
+ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx, + OlsLinearRegressionTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); + var env = CatalogUtils.GetEnvironment(ctx); - return new OlsLinearRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings); + return new OlsLinearRegressionTrainer(env, options); } /// @@ -44,11 +65,31 @@ public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionCon public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string labelColumn = DefaultColumnNames.Label, string featureColumn = DefaultColumnNames.Features, - Action advancedSettings = null) + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + var options = new SymSgdClassificationTrainer.Options + { + LabelColumn = labelColumn, + FeatureColumn = featureColumn, + }; + + return new SymSgdClassificationTrainer(env, options); + } + + /// + /// Predict a target using a linear binary classification model trained with the SymSgdClassificationTrainer. + /// + /// The BinaryClassificationContext.BinaryClassificationTrainers instance. + /// Algorithm advanced options. 
+ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + SymSgdClassificationTrainer.Options options) { Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(options, nameof(options)); var env = CatalogUtils.GetEnvironment(ctx); - return new SymSgdClassificationTrainer(env, labelColumn, featureColumn, advancedSettings); + return new SymSgdClassificationTrainer(env, options); } /// diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index a4f42bbb39e..0cc61f4b22b 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -19,7 +19,7 @@ using Microsoft.ML.Trainers.HalLearners; using Microsoft.ML.Training; -[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments), +[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Options), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, OlsLinearRegressionTrainer.UserNameValue, OlsLinearRegressionTrainer.LoadNameValue, @@ -34,9 +34,10 @@ namespace Microsoft.ML.Trainers.HalLearners { /// + [BestFriend] public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase, OlsLinearRegressionModelParameters> { - public sealed class Arguments : LearnerInputBaseWithWeight + public sealed class Options : LearnerInputBaseWithWeight { // Adding L2 regularization turns this into a form of ridge regression, // rather than, strictly speaking, ordinary least squares. 
But it is an @@ -46,13 +47,16 @@ public sealed class Arguments : LearnerInputBaseWithWeight [TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[] { 1e-6f, 0.1f, 1f })] public float L2Weight = 1e-6f; + /// + /// Whether to calculate per parameter significance statistics. + /// [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to calculate per parameter significance statistics", ShortName = "sig")] public bool PerParameterSignificance = true; } - public const string LoadNameValue = "OLSLinearRegression"; - public const string UserNameValue = "Ordinary Least Squares (Regression)"; - public const string ShortName = "ols"; + internal const string LoadNameValue = "OLSLinearRegression"; + internal const string UserNameValue = "Ordinary Least Squares (Regression)"; + internal const string ShortName = "ols"; internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features " + "that minimizes the square loss function."; @@ -68,24 +72,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight /// /// Initializes a new instance of /// - /// The environment to use. - /// The name of the labelColumn column. - /// The name of the feature column. - /// The name for the optional example weight column. - /// A delegate to apply all the advanced arguments to the algorithm. 
- public OlsLinearRegressionTrainer(IHostEnvironment env, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weights = null, - Action advancedSettings = null) - : this(env, ArgsInit(featureColumn, labelColumn, weights, advancedSettings)) - { - } - - /// - /// Initializes a new instance of - /// - internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) + internal OlsLinearRegressionTrainer(IHostEnvironment env, Options args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { @@ -95,21 +82,6 @@ internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) _perParameterSignificance = args.PerParameterSignificance; } - private static Arguments ArgsInit(string featureColumn, - string labelColumn, - string weightColumn, - Action advancedSettings) - { - var args = new Arguments(); - - // Apply the advanced args, if the user supplied any. 
- advancedSettings?.Invoke(args); - args.FeatureColumn = featureColumn; - args.LabelColumn = labelColumn; - args.WeightColumn = weightColumn; - return args; - } - protected override RegressionPredictionTransformer MakeTransformer(OlsLinearRegressionModelParameters model, Schema trainSchema) => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); @@ -518,14 +490,14 @@ public static void Pptri(Layout layout, UpLo uplo, int n, Double[] ap) UserName = UserNameValue, ShortName = ShortName, XmlInclude = new[] { @"" })] - public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input) + public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainOLS"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new OlsLinearRegressionTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); @@ -579,7 +551,7 @@ private static VersionInfo GetVersionInfo() /// are all null. A model may not have per parameter statistics because either /// there were not more examples than parameters in the model, or because they /// were explicitly suppressed in training by setting - /// + /// /// to false. 
/// public bool HasStatistics => _standardErrors != null; diff --git a/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs index 378fcf459ae..9a3e26f765c 100644 --- a/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.HalLearners/Properties/AssemblyInfo.cs @@ -5,6 +5,8 @@ using System.Runtime.CompilerServices; using Microsoft.ML; +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] + [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners.StaticPipe" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)] diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index c4a57c65478..c1141bed81b 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -20,7 +20,7 @@ using Microsoft.ML.Training; using Microsoft.ML.Transforms; -[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments), +[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Options), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, SymSgdClassificationTrainer.UserNameValue, SymSgdClassificationTrainer.LoadNameValue, @@ -33,48 +33,78 @@ namespace Microsoft.ML.Trainers.SymSgd using TPredictor = IPredictorWithFeatureWeights; /// + [BestFriend] public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase, TPredictor> { internal const string LoadNameValue = "SymbolicSGD"; internal const string UserNameValue = "Symbolic SGD (binary)"; internal const string ShortName = "SymSGD"; - public sealed class Arguments : LearnerInputBaseWithLabel + public sealed class Options : LearnerInputBaseWithLabel { + 
/// + /// Degree of lock-free parallelism. Determinism not guaranteed. + /// Multi-threading is not supported currently. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " + "Multi-threading is not supported currently.", ShortName = "nt")] public int? NumberOfThreads; + /// + /// Number of passes over the data. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Number of passes over the data.", ShortName = "iter", SortOrder = 50)] [TGUI(SuggestedSweeps = "1,5,10,20,30,40,50")] [TlcModule.SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 })] public int NumberOfIterations = 50; + /// + /// Tolerance for difference in average loss in consecutive passes. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance for difference in average loss in consecutive passes.", ShortName = "tol")] public float Tolerance = 1e-4f; + /// + /// Learning rate. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", NullName = "", SortOrder = 51)] [TGUI(SuggestedSweeps = ",1e1,1e0,1e-1,1e-2,1e-3")] [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { "", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f })] public float? LearningRate; + /// + /// L2 regularization. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization", ShortName = "l2", SortOrder = 52)] [TGUI(SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7")] [TlcModule.SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f })] public float L2Regularization; + /// + /// The number of iterations each thread learns a local model until combining it with the + /// global model. Low value means more updated global model and high value means less cache traffic. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The number of iterations each thread learns a local model until combining it with the " + "global model. 
Low value means more updated global model and high value means less cache traffic.", ShortName = "freq", NullName = "")] [TGUI(SuggestedSweeps = ",5,20")] [TlcModule.SweepableDiscreteParam("UpdateFrequency", new object[] { "", 5, 20 })] public int? UpdateFrequency; + /// + /// The acceleration memory budget in MB. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "The acceleration memory budget in MB", ShortName = "accelMemBudget")] public long MemorySize = 1024; + /// + /// Setting this to true causes the data to be shuffled. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data?", ShortName = "shuf")] public bool Shuffle = true; + /// + /// Apply weight to the positive class, for imbalanced data. + /// [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")] public float PositiveInstanceWeight = 1; @@ -88,7 +118,7 @@ public void Check(IExceptionContext ectx) } public override TrainerInfo Info { get; } - private readonly Arguments _args; + private readonly Options _args; /// /// This method ensures that the data meets the requirements of this trainer and its @@ -152,32 +182,7 @@ private protected override TPredictor TrainModelCore(TrainContext context) /// /// Initializes a new instance of /// - /// The private instance of . - /// The name of the label column. - /// The name of the feature column. - /// A delegate to apply all the advanced arguments to the algorithm. - public SymSgdClassificationTrainer(IHostEnvironment env, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), - TrainerUtils.MakeBoolScalarLabel(labelColumn)) - { - _args = new Arguments(); - - // Apply the advanced args, if the user supplied any. 
- _args.Check(Host); - advancedSettings?.Invoke(_args); - _args.FeatureColumn = featureColumn; - _args.LabelColumn = labelColumn; - - Info = new TrainerInfo(supportIncrementalTrain: true); - } - - /// - /// Initializes a new instance of - /// - internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args) + internal SymSgdClassificationTrainer(IHostEnvironment env, Options args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { @@ -218,14 +223,14 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc UserName = SymSgdClassificationTrainer.UserNameValue, ShortName = SymSgdClassificationTrainer.ShortName, XmlInclude = new[] { @"" })] - public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input) + public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("TrainSymSGD"); host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - return LearnerEntryPointsUtils.Train(host, input, + return LearnerEntryPointsUtils.Train(host, input, () => new SymSgdClassificationTrainer(host, input), () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs index 048ba54da7d..3494d4b48ef 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/OlsLinearRegressionTests.cs @@ -13,7 +13,7 @@ public partial class TrainerEstimators public void TestEstimatorOlsLinearRegression() { var dataView = GetRegressionPipeline(); - var trainer = new OlsLinearRegressionTrainer(Env, "Label", 
"Features"); + var trainer = new OlsLinearRegressionTrainer(Env, new OlsLinearRegressionTrainer.Options()); TestEstimatorCore(trainer, dataView); var model = trainer.Fit(dataView); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index 30dc030aed4..522363d70c5 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -16,7 +16,7 @@ public partial class TrainerEstimators public void TestEstimatorSymSgdClassificationTrainer() { (var pipe, var dataView) = GetBinaryClassificationPipeline(); - var trainer = new SymSgdClassificationTrainer(Env, "Label", "Features"); + var trainer = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()); var pipeWithTrainer = pipe.Append(trainer); TestEstimatorCore(pipeWithTrainer, dataView); @@ -35,10 +35,10 @@ public void TestEstimatorSymSgdInitPredictor() var initPredictor = new SdcaBinaryTrainer(Env, "Label", "Features").Fit(transformedData); var data = initPredictor.Transform(transformedData); - var withInitPredictor = new SymSgdClassificationTrainer(Env, "Label", "Features").Train(transformedData, initialPredictor: initPredictor.Model); + var withInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()).Train(transformedData, initialPredictor: initPredictor.Model); var outInitData = withInitPredictor.Transform(transformedData); - var notInitPredictor = new SymSgdClassificationTrainer(Env, "Label", "Features").Train(transformedData); + var notInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()).Train(transformedData); var outNoInitData = notInitPredictor.Transform(transformedData); int numExamples = 10;