Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer estimator cleanup for FastTrees and LightGBM #1352

Merged
merged 12 commits into from
Oct 27, 2018
37 changes: 35 additions & 2 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim

public abstract PredictionKind PredictionKind { get; }

public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null)
public TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
Expand Down Expand Up @@ -149,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet,

protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);

private RoleMappedData MakeRoles(IDataView data) =>
protected virtual RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);

IPredictor ITrainer.Train(TrainContext context) => Train(context);
}

/// <summary>
/// Base class for a 'simple trainer' that, in addition to one feature column, one label column
/// and an optional weight column, can take an optional group ID column.
/// It produces a 'prediction transformer'.
/// </summary>
public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
    where TTransformer : ISingleFeaturePredictionTransformer<TModel>
    where TModel : IPredictor
{
    /// <summary>
    /// The optional group ID column that ranking trainers expect; null when none was supplied.
    /// </summary>
    public readonly SchemaShape.Column GroupIdColumn;

    /// <summary>
    /// Forwards <paramref name="host"/>, <paramref name="feature"/>, <paramref name="label"/> and
    /// <paramref name="weight"/> to the base constructor and stores <paramref name="groupId"/>,
    /// which may be null.
    /// </summary>
    public TrainerEstimatorBaseWithGroupId(
        IHost host,
        SchemaShape.Column feature,
        SchemaShape.Column label,
        SchemaShape.Column weight = null,
        SchemaShape.Column groupId = null)
        : base(host, feature, label, weight)
    {
        Host.CheckValueOrNull(groupId);
        GroupIdColumn = groupId;
    }

    /// <summary>
    /// Builds the role mapping over <paramref name="data"/>, passing the name of
    /// <see cref="GroupIdColumn"/> (when present) as the group role alongside the
    /// label, feature and weight roles.
    /// </summary>
    protected override RoleMappedData MakeRoles(IDataView data)
    {
        return new RoleMappedData(
            data,
            label: LabelColumn?.Name,
            feature: FeatureColumn.Name,
            group: GroupIdColumn?.Name,
            weight: WeightColumn?.Name);
    }
}
}
73 changes: 11 additions & 62 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="labelColumn">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
/// <param name="columnName">name of the column; when null, no column is created and null is returned</param>
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
{
if (columnName == null)
Copy link
Member

@singlis singlis Oct 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check needed? It looks like SchemaShape.Column constructor also checks the name for null or empty string. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, because the check inside the constructor will throw if we pass null.


In reply to: 227989462 [](ancestors = 227989462)

return null;

return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the feature column.
Expand All @@ -377,69 +382,13 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
/// <param name="isExplicit">whether the weight column was explicitly defined; when false, null is returned</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
{
if (weightColumn == null)
if (weightColumn == null || !isExplicit)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Throws, through <paramref name="host"/>, when <paramref name="argValue"/> differs from
/// <paramref name="defaultColName"/>, i.e. when a column name was set through the arguments object.
/// </summary>
/// <param name="host">environment used to create the exception</param>
/// <param name="defaultColName">the default name of the column being checked</param>
/// <param name="argValue">the column name currently held by the arguments</param>
private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue)
{
    if (argValue != defaultColName)
        throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead.");
}

/// <summary>
/// Verifies that the label, feature, weight and group ID column names held by <paramref name="args"/>
/// are still the default ones, i.e. that they were not supplied through the advancedSettings parameter
/// when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args)
{
    // Label, feature and weight names are always compared against their defaults.
    CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
    CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
    CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);

    // The group ID name is only compared when the args carry one.
    if (args.GroupIdColumn == null)
        return;

    CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn);
}

/// <summary>
/// Verifies that the label, feature and weight column names held by <paramref name="args"/>
/// are still the default ones, i.e. that they were not supplied through the advancedSettings parameter
/// when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args)
{
    // Compare, in order, the label, feature and weight names against their defaults.
    CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
    CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
    CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);
}

/// <summary>
/// Verifies that the label and feature column names held by <paramref name="args"/>
/// are still the default ones, i.e. that they were not supplied through the advancedSettings parameter
/// when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args)
{
    // Compare, in order, the label and feature names against their defaults.
    CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
    CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
}

/// <summary>
/// If, after applying the advancedArgs delegate, the setting is different than the default value
/// and is also different than the value supplied directly to the extension method, warn the user.
/// </summary>
/// <param name="channel">channel that receives the warning</param>
/// <param name="methodParam">value supplied directly to the method</param>
/// <param name="defaultVal">default value of the setting</param>
/// <param name="setting">value of the setting after the advanced settings were applied</param>
/// <param name="argName">name of the argument, used in the warning text</param>
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
{
    // The static object.Equals is null-safe, so a null reference-typed setting
    // no longer causes a NullReferenceException; for non-null values it still
    // dispatches to setting.Equals(...).
    bool differsFromDefault = !object.Equals(setting, defaultVal);
    bool differsFromDirect = !object.Equals(setting, methodParam);

    if (differsFromDefault && differsFromDirect)
        channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}");
}
}

/// <summary>
Expand Down
20 changes: 17 additions & 3 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,24 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
{
}

protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
protected BoostingFastTreeTrainerBase(IHostEnvironment env,
SchemaShape.Column label,
string featureColumn,
string weightColumn,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDocumentsInLeafs,
double learningRate,
Action<TArgs> advancedSettings)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
{

if (Args.LearningRates != learningRate)
{
using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
Args.LearningRates = learningRate;
Copy link
Contributor

@Zruty0 Zruty0 Oct 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Args [](start = 16, length = 4)

indenting #Resolved

}
}

protected override void CheckArgs(IChannel ch)
Expand Down
57 changes: 21 additions & 36 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.Conversion;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
Expand Down Expand Up @@ -45,7 +46,7 @@ internal static class FastTreeShared
}

public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
TrainerEstimatorBase<TTransformer, TModel>
TrainerEstimatorBaseWithGroupId<TTransformer, TModel>
where TTransformer: ISingleFeaturePredictionTransformer<TModel>
where TArgs : TreeArgs, new()
where TModel : IPredictorProducing<Float>
Expand Down Expand Up @@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
/// <summary>
/// Constructor to use when instantiating the classes deriving from here through the API.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
private protected FastTreeTrainerBase(IHostEnvironment env,
SchemaShape.Column label,
string featureColumn,
string weightColumn,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDocumentsInLeafs,
Action<TArgs> advancedSettings)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();

// set up the directly provided values
// override with the directly provided values.
Args.NumLeaves = numLeaves;
Args.NumTrees = numTrees;
Args.MinDocumentsInLeafs = minDocumentsInLeafs;

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

if (weightColumn != null)
Args.WeightColumn = weightColumn;
Args.WeightColumn = Optional<string>.Explicit(weightColumn); ;

if (groupIdColumn != null)
Args.GroupIdColumn = groupIdColumn;
Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn); ;

// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
Expand All @@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
Host.CheckValue(args, nameof(args));
Args = args;
Expand Down Expand Up @@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
return Float.PositiveInfinity;
}

/// <summary>
/// If, after applying the advancedSettings delegate, the args are different than the default value
/// and are also different than the value supplied directly to the extension method, warn the user
/// about which value is being used.
/// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
/// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
/// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
/// </summary>
protected void CheckArgsAndAdvancedSettingMismatch(
    int numLeaves,
    int numTrees,
    int minDocumentsInLeafs,
    double learningRate,
    BoostedTreeArgs snapshot,
    BoostedTreeArgs currentArgs)
{
    using (var channel = Host.Start("Comparing advanced settings with the directly provided values."))
    {
        // For each tunable: directly supplied value, value from the snapshot, value in the current args.
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate));
    }
}

private void Initialize(IHostEnvironment env)
{
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
Expand Down
Loading