Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public API for KMeansPredictor #1739

Merged
merged 12 commits into from
Dec 5, 2018
47 changes: 47 additions & 0 deletions docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Samples.Dynamic
{
public class KMeans_example
{
public static void KMeans()
{
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
// as well as the source of randomness.
var ml = new MLContext();

// Get a small dataset as an IEnumerable and convert it to an IDataView.
IEnumerable<SamplesUtils.DatasetUtils.SampleInfertData> data = SamplesUtils.DatasetUtils.GetInfertData();
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
var trainData = ml.CreateStreamingDataView(data);

// Preview of the data.
//
// Age Case Education Induced Parity PooledStratum RowNum ...
// 26 1 0-5yrs 1 6 3 1 ...
// 42 1 0-5yrs 1 1 1 2 ...
// 39 1 0-5yrs 2 6 4 3 ...
// 34 1 0-5yrs 2 4 2 4 ...
// 35 1 6-11yrs 1 3 32 5 ...

// A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them.
string outputColumnName = "Features";
var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" })
.Append(ml.Clustering.Trainers.KMeans(outputColumnName, clustersCount: 2));

var model = pipeline.Fit(trainData);

// Get centroids and k from KMeansModelParameters.
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
VBuffer<float>[] centroids = default;
int k;

var modelParams = model.LastTransformer.Model;
modelParams.GetClusterCentroids(ref centroids, out k);

var centroid = centroids[0].GetValues();
Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.ML.Core/Data/IValueMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace Microsoft.ML.Runtime.Data
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
public interface IValueMapper
[BestFriend]
internal interface IValueMapper
{
ColumnType InputType { get; }
ColumnType OutputType { get; }
Expand All @@ -43,7 +44,8 @@ public interface IValueMapper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
public interface IValueMapperDist : IValueMapper
[BestFriend]
internal interface IValueMapperDist : IValueMapper
{
ColumnType DistType { get; }

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public virtual void Save(ModelSaveContext ctx)
SaveCore(ctx);
}

protected virtual void SaveCore(ModelSaveContext ctx)
[BestFriend]
private protected virtual void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public interface ISampleableDistribution<TResult> : IDistribution<TResult>
/// <summary>
/// Predictors that can output themselves in a human-readable text format
/// </summary>
public interface ICanSaveInTextFormat
[BestFriend]
internal interface ICanSaveInTextFormat
{
void SaveAsText(TextWriter writer, RoleMappedSchema schema);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator ca
saver?.SaveAsIni(writer, schema, Calibrator);
}

public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
// REVIEW: What about the calibrator?
var saver = SubPredictor as ICanSaveInTextFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper
// ValueMapper or FloatPredictor is non-null (or both). With these guarantees,
// the score value type (_scoreType) can be determined.
protected readonly IPredictor Predictor;
protected readonly IValueMapper ValueMapper;
private protected readonly IValueMapper ValueMapper;
protected readonly ColumnType ScoreType;

bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLo
return new EnsembleDistributionPredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ct
return new EnsemblePredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadCont
ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);

Expand Down Expand Up @@ -128,7 +128,7 @@ protected override void SaveCore(ModelSaveContext ctx)
/// <summary>
/// Output the INI model to a given writer
/// </summary>
public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
using (var ch = Host.Start("SaveAsText"))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoad
return new EnsembleMultiClassPredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
17 changes: 12 additions & 5 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2829,8 +2829,12 @@ public abstract class FastTreePredictionWrapper :

protected abstract uint VerCategoricalSplitSerialized { get; }

public ColumnType InputType { get; }
public ColumnType OutputType => NumberType.Float;
protected internal readonly ColumnType InputType;
ColumnType IValueMapper.InputType => InputType;

protected readonly ColumnType OutputType;
ColumnType IValueMapper.OutputType => OutputType;

bool ICanSavePfa.CanSavePfa => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

Expand All @@ -2852,6 +2856,7 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsem
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);

InputType = new VectorType(NumberType.Float, NumFeatures);
OutputType = NumberType.Float;
}

protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
Expand Down Expand Up @@ -2889,9 +2894,11 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad
// tricks.

InputType = new VectorType(NumberType.Float, NumFeatures);
OutputType = NumberType.Float;
}

protected override void SaveCore(ModelSaveContext ctx)
[BestFriend]
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);

Expand All @@ -2906,7 +2913,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(NumFeatures);
}

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
{
Host.Check(typeof(TIn) == typeof(VBuffer<Float>));
Host.Check(typeof(TOut) == typeof(Float));
Expand Down Expand Up @@ -2964,7 +2971,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema)
/// <summary>
/// Output the INI model to a given writer
/// </summary>
public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
Host.CheckValue(writer, nameof(writer));
Host.CheckValueOrNull(schema);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private FastTreeBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ private FastTreeRankingPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ private FastTreeRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ private FastTreeTweediePredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
16 changes: 9 additions & 7 deletions src/Microsoft.ML.FastTree/GamTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ public abstract class GamPredictorBase : PredictorBase<float>,
public readonly double Intercept;
private readonly int _numFeatures;
private readonly ColumnType _inputType;
private readonly ColumnType _outputType;
// These would be the bins for a totally sparse input.
private readonly int[] _binsAtAllZero;
// The output value for all zeros
Expand All @@ -640,9 +641,8 @@ public abstract class GamPredictorBase : PredictorBase<float>,
private readonly int _inputLength;
private readonly Dictionary<int, int> _inputFeatureToDatasetFeatureMap;

public ColumnType InputType => _inputType;

public ColumnType OutputType => NumberType.Float;
ColumnType IValueMapper.InputType => _inputType;
ColumnType IValueMapper.OutputType => _outputType;

private protected GamPredictorBase(IHostEnvironment env, string name,
int inputLength, Dataset trainSet, double meanEffect, double[][] binEffects, int[] featureMap)
Expand All @@ -658,6 +658,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name,

_numFeatures = binEffects.Length;
_inputType = new VectorType(NumberType.Float, _inputLength);
_outputType = NumberType.Float;
_featureMap = featureMap;

Intercept = meanEffect;
Expand Down Expand Up @@ -762,6 +763,7 @@ protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext c
}

_inputType = new VectorType(NumberType.Float, _inputLength);
_outputType = NumberType.Float;
}

public override void Save(ModelSaveContext ctx)
Expand Down Expand Up @@ -975,7 +977,7 @@ public double[] GetFeatureWeights(int featureIndex)
return featureWeights;
}

public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
{
Host.CheckValue(writer, nameof(writer));
Host.CheckValueOrNull(schema);
Expand Down Expand Up @@ -1018,7 +1020,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema)

public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
{
SaveAsText(writer, schema);
((ICanSaveInTextFormat)this).SaveAsText(writer, schema);
}

/// <summary>
Expand Down Expand Up @@ -1097,7 +1099,7 @@ private sealed class Context
/// These are the number of input features, as opposed to the number of features used within GAM
/// which may be lower.
/// </summary>
public int NumFeatures => _pred.InputType.VectorSize;
public int NumFeatures => _pred._inputType.VectorSize;

public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluator eval)
{
Expand Down Expand Up @@ -1317,7 +1319,7 @@ private FeatureInfo(Context context, int index, int internalIndex, int[] catsMap
public static FeatureInfo GetInfoForIndex(Context context, int index)
{
Contracts.AssertValue(context);
Contracts.Assert(0 <= index && index < context._pred.InputType.ValueCount);
Contracts.Assert(0 <= index && index < context._pred._inputType.ValueCount);
lock (context._pred)
{
int internalIndex;
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)]

[assembly: WantsToBeBestFriends]
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx
_quantileSampleCount = ctx.Reader.ReadInt32();
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
Double tss = 0; // total sum of squares
using (var cursor = cursorFactory.Create())
{
var lrPredictor = new LinearRegressionPredictor(Host, in weights, bias);
IValueMapper lrPredictor = new LinearRegressionPredictor(Host, in weights, bias);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IValueMapper lrPredictor [](start = 16, length = 24)

Good stuff

var lrMap = lrPredictor.GetMapper<VBuffer<float>, float>();
float yh = default;
while (cursor.MoveNext())
Expand Down Expand Up @@ -682,7 +682,7 @@ private OlsLinearRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx)
ProbCheckDecode(_pValues[i]);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.KMeansClustering;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Trainers.KMeans;
using System;

namespace Microsoft.ML
Expand Down
Loading