Fixing ModelParameter discrepancies #2968

Merged: 11 commits, Mar 19, 2019
@@ -47,7 +47,7 @@ private class BinaryOutputRow
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
=> output.AboveAverage = input.MedianHomeValue > 22.6;

public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
public static float[] GetLinearModelWeights(OlsModelParameters linearModel)
{
return linearModel.Weights.ToArray();
}
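For context, a minimal sketch of how the renamed OlsModelParameters reaches this helper from a training pipeline. The Ols() catalog extension (Microsoft.ML.Mkl.Components), the column names, and the sketch class itself are assumptions for illustration; the Weights read at the end is exactly what the GetLinearModelWeights helper above does.

```csharp
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Trainers;

internal static class OlsUsageSketch
{
    // Sketch only: the Ols() catalog extension and the column names are assumptions;
    // OlsModelParameters.Weights is what the GetLinearModelWeights helper above reads.
    public static float[] TrainAndGetWeights(MLContext mlContext, IDataView trainingData)
    {
        var trainer = mlContext.Regression.Trainers.Ols(
            labelColumnName: "Label", featureColumnName: "Features");

        var model = trainer.Fit(trainingData);

        // After this PR the typed model parameters use the shortened name.
        OlsModelParameters linearModel = model.Model;
        return linearModel.Weights.ToArray();
    }
}
```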
34 changes: 17 additions & 17 deletions src/Microsoft.ML.FastTree/GamClassification.cs
@@ -20,9 +20,9 @@
GamBinaryClassificationTrainer.LoadNameValue,
GamBinaryClassificationTrainer.ShortName, DocName = "trainer/GAM.md")]

[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(GamBinaryModelParameters), null, typeof(SignatureLoadModel),
"GAM Binary Class Predictor",
BinaryClassificationGamModelParameters.LoaderSignature)]
GamBinaryModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers.FastTree
{
@@ -32,8 +32,8 @@ namespace Microsoft.ML.Trainers.FastTree
/// <include file='doc.xml' path='doc/members/member[@name="GAM_remarks"]/*' />
public sealed class GamBinaryClassificationTrainer :
GamTrainerBase<GamBinaryClassificationTrainer.Options,
BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>,
CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>,
CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
{
/// <summary>
/// Options for the <see cref="GamBinaryClassificationTrainer"/>.
@@ -111,13 +111,13 @@ private static bool[] ConvertTargetsToBool(double[] targets)
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
return boolArray;
}
private protected override CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
private protected override CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
{
TrainBase(context);
var predictor = new BinaryClassificationGamModelParameters(Host,
var predictor = new GamBinaryModelParameters(Host,
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
return new ValueMapperCalibratedModelParameters<BinaryClassificationGamModelParameters, PlattCalibrator>(Host, predictor, calibrator);
return new ValueMapperCalibratedModelParameters<GamBinaryModelParameters, PlattCalibrator>(Host, predictor, calibrator);
}

private protected override ObjectiveFunctionBase CreateObjectiveFunction()
@@ -146,15 +146,15 @@ private protected override void DefinePruningTest()
PruningTest = new TestHistory(validTest, PruningLossIndex);
}

private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
MakeTransformer(CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
MakeTransformer(CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains a <see cref="GamBinaryClassificationTrainer"/> using both training and validation data, returns
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
/// </summary>
public BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
public BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
=> TrainTransformer(trainData, validationData);
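A hedged sketch of how a caller reaches the renamed GamBinaryModelParameters through the Platt-calibrated wrapper. The BinaryClassification.Trainers.Gam() entry point, the prepared IDataViews, and the sketch class are assumptions; the Fit(train, validation) overload and the CalibratedModelParametersBase/SubModel types follow this file.

```csharp
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Trainers.FastTree;

internal static class GamBinaryUsageSketch
{
    // Sketch only: the Gam() catalog extension is assumed to return
    // GamBinaryClassificationTrainer; trainData/validationData are prepared elsewhere.
    public static GamBinaryModelParameters TrainBinaryGam(
        MLContext mlContext, IDataView trainData, IDataView validationData)
    {
        var trainer = mlContext.BinaryClassification.Trainers.Gam();

        // The train + validation overload shown above.
        var transformer = trainer.Fit(trainData, validationData);

        // The trained model is Platt-calibrated; the raw GAM sits in SubModel.
        CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> calibrated =
            transformer.Model;
        return calibrated.SubModel;
    }
}
```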

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
@@ -171,7 +171,7 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
/// <summary>
/// The model parameters class for Binary Classification GAMs
/// </summary>
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
public sealed class GamBinaryModelParameters : GamModelParametersBase, IPredictorProducing<float>
@eerhardt (Member) commented on Mar 18, 2019:

BinaryClassification?

Do we need to put Classification in the name? We do everywhere else. #Resolved

Author (Member) replied:

For ModelParameters we do not. Please see the other comment.

In reply to: 266626783

{
internal const string LoaderSignature = "BinaryClassGamPredictor";
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -188,11 +188,11 @@ public sealed class BinaryClassificationGamModelParameters : GamModelParametersB
/// <param name="featureToInputMap">A map from the feature shape functions, as described by <paramref name="binUpperBounds"/> and <paramref name="binEffects"/>,
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
/// a shape function.</param>
internal BinaryClassificationGamModelParameters(IHostEnvironment env,
internal GamBinaryModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private GamBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }

private static VersionInfo GetVersionInfo()
@@ -205,7 +205,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -214,12 +214,12 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

var predictor = new BinaryClassificationGamModelParameters(env, ctx);
var predictor = new GamBinaryModelParameters(env, ctx);
ICalibrator calibrator;
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
if (calibrator == null)
return predictor;
return new SchemaBindableCalibratedModelParameters<BinaryClassificationGamModelParameters, ICalibrator>(env, predictor, calibrator);
return new SchemaBindableCalibratedModelParameters<GamBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
}

private protected override void SaveCore(ModelSaveContext ctx)
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/GamModelParameters.cs
@@ -879,12 +879,12 @@ private Context Init(IChannel ch)
// 2. RegressionGamModelParameters
// For (1), the trained model, GamModelParametersBase, is a field we need to extract. For (2),
// we don't need to do anything because RegressionGamModelParameters is derived from GamModelParametersBase.
var calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
var calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
while (calibrated != null)
{
hadCalibrator = true;
rawPred = calibrated.SubModel;
calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
}
var pred = rawPred as GamModelParametersBase;
ch.CheckUserArg(pred != null, nameof(ImplOptions.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));
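The same unwrapping logic, as a hedged caller-side sketch: given a predictor of unknown flavor, peel off any Platt calibration to reach GamModelParametersBase. The helper class and method names are hypothetical; the types and the SubModel loop follow the hunk above.

```csharp
using Microsoft.ML.Calibrators;
using Microsoft.ML.Trainers.FastTree;

internal static class GamUnwrapSketch
{
    // Sketch only: mirrors the unwrap loop above for caller code holding a predictor
    // of unknown flavor (Platt-calibrated binary GAM vs. raw regression GAM).
    // This class and method are hypothetical helpers, not ML.NET API.
    public static GamModelParametersBase ExtractGam(object rawPred)
    {
        var calibrated = rawPred as
            CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
        while (calibrated != null)
        {
            rawPred = calibrated.SubModel;
            calibrated = rawPred as
                CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
        }

        // Null if the predictor was not a GAM at all.
        return rawPred as GamModelParametersBase;
    }
}
```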
28 changes: 14 additions & 14 deletions src/Microsoft.ML.FastTree/GamRegression.cs
@@ -17,17 +17,17 @@
GamRegressionTrainer.LoadNameValue,
GamRegressionTrainer.ShortName, DocName = "trainer/GAM.md")]

[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(GamRegressionModelParameters), null, typeof(SignatureLoadModel),
"GAM Regression Predictor",
RegressionGamModelParameters.LoaderSignature)]
GamRegressionModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers.FastTree
{
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for training a regression model with generalized additive models (GAM).
/// </summary>
/// <include file='doc.xml' path='doc/members/member[@name="GAM_remarks"]/*' />
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<RegressionGamModelParameters>, RegressionGamModelParameters>
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<GamRegressionModelParameters>, GamRegressionModelParameters>
{
/// <summary>
/// Options for the <see cref="GamRegressionTrainer"/>.
@@ -80,10 +80,10 @@ private protected override void CheckLabel(RoleMappedData data)
data.CheckRegressionLabel();
}

private protected override RegressionGamModelParameters TrainModelCore(TrainContext context)
private protected override GamRegressionModelParameters TrainModelCore(TrainContext context)
{
TrainBase(context);
return new RegressionGamModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
return new GamRegressionModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
}

private protected override ObjectiveFunctionBase CreateObjectiveFunction()
@@ -99,14 +99,14 @@ private protected override void DefinePruningTest()
PruningTest = new TestHistory(validTest, PruningLossIndex);
}

private protected override RegressionPredictionTransformer<RegressionGamModelParameters> MakeTransformer(RegressionGamModelParameters model, DataViewSchema trainSchema)
=> new RegressionPredictionTransformer<RegressionGamModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
private protected override RegressionPredictionTransformer<GamRegressionModelParameters> MakeTransformer(GamRegressionModelParameters model, DataViewSchema trainSchema)
=> new RegressionPredictionTransformer<GamRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains a <see cref="GamRegressionTrainer"/> using both training and validation data, returns
/// a <see cref="RegressionPredictionTransformer{RegressionGamModelParameters}"/>.
/// </summary>
public RegressionPredictionTransformer<RegressionGamModelParameters> Fit(IDataView trainData, IDataView validationData)
public RegressionPredictionTransformer<GamRegressionModelParameters> Fit(IDataView trainData, IDataView validationData)
=> TrainTransformer(trainData, validationData);
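As with the binary trainer, a hedged sketch of the regression path. The Regression.Trainers.Gam() entry point, the data views, and the sketch class are assumptions; the Fit(train, validation) overload and the transformer's typed Model follow this file. No calibrator is involved here, so the GAM parameters come back directly.

```csharp
using Microsoft.ML;
using Microsoft.ML.Trainers.FastTree;

internal static class GamRegressionUsageSketch
{
    // Sketch only: the regression Gam() catalog extension is assumed to return
    // GamRegressionTrainer; trainData/validationData are prepared elsewhere.
    public static GamRegressionModelParameters TrainRegressionGam(
        MLContext mlContext, IDataView trainData, IDataView validationData)
    {
        var trainer = mlContext.Regression.Trainers.Gam();

        // The train + validation overload shown above.
        var transformer = trainer.Fit(trainData, validationData);

        // No calibrator on the regression path: Model is the GAM parameters directly.
        return transformer.Model;
    }
}
```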

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
@@ -121,7 +121,7 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
/// <summary>
/// The model parameters class for Regression GAMs
/// </summary>
public sealed class RegressionGamModelParameters : GamModelParametersBase
public sealed class GamRegressionModelParameters : GamModelParametersBase
{
internal const string LoaderSignature = "RegressionGamPredictor";
private protected override PredictionKind PredictionKind => PredictionKind.Regression;
@@ -138,11 +138,11 @@ public sealed class RegressionGamModelParameters : GamModelParametersBase
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
/// a shape function.</param>
internal RegressionGamModelParameters(IHostEnvironment env,
internal GamRegressionModelParameters(IHostEnvironment env,
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength = -1, int[] featureToInputMap = null)
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private GamRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx) { }

private static VersionInfo GetVersionInfo()
Expand All @@ -155,16 +155,16 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(GamRegressionModelParameters).Assembly.FullName);
}

private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
private static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

return new RegressionGamModelParameters(env, ctx);
return new GamRegressionModelParameters(env, ctx);
}

private protected override void SaveCore(ModelSaveContext ctx)