Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public API for KMeansPredictor #1739

Merged
merged 12 commits into from
Dec 5, 2018
47 changes: 47 additions & 0 deletions docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Samples.Dynamic
{
/// <summary>
/// Sample showing how to train a KMeans clustering model and read back its cluster centroids.
/// </summary>
public class KMeans_example
{
/// <summary>
/// Fits a pipeline that concatenates the Age, Parity and Induced columns into a Features column and
/// trains a KMeans model with two clusters on it, then writes the coordinates of centroid 0 to the console.
/// </summary>
public static void KMeans()
{
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
// as well as the source of randomness.
var ml = new MLContext();

// Get a small dataset as an IEnumerable and convert it to an IDataView.
IEnumerable<SamplesUtils.DatasetUtils.SampleInfertData> data = SamplesUtils.DatasetUtils.GetInfertData();
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
var trainData = ml.CreateStreamingDataView(data);

// Preview of the data.
//
// Age Case Education Induced Parity PooledStratum RowNum ...
// 26 1 0-5yrs 1 6 3 1 ...
// 42 1 0-5yrs 1 1 1 2 ...
// 39 1 0-5yrs 2 6 4 3 ...
// 34 1 0-5yrs 2 4 2 4 ...
// 35 1 6-11yrs 1 3 32 5 ...

// A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them.
// The same column name is passed to both Concatenate (as its output) and the KMeans trainer.
string outputColumnName = "Features";
var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" })
.Append(ml.Clustering.Trainers.KMeans(outputColumnName, clustersCount: 2));

// Fit the whole pipeline (concatenation + KMeans training) on the training data.
var model = pipeline.Fit(trainData);

// Get centroids and k from KMeansModelParameters.
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
// centroids is handed over by ref and k is received through an out parameter (see the call below).
VBuffer<float>[] centroids = default;
int k;

// The KMeans trainer was the last stage appended to the pipeline, so the model parameters are read
// from the last transformer of the fitted model.
var modelParams = model.LastTransformer.Model;
modelParams.GetClusterCentroids(ref centroids, out k);

// Only the first centroid is printed; k is retrieved but not otherwise used in this sample.
var centroid = centroids[0].GetValues();
Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.ML.Core/Data/IValueMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace Microsoft.ML.Runtime.Data
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
public interface IValueMapper
[BestFriend]
internal interface IValueMapper
{
ColumnType InputType { get; }
ColumnType OutputType { get; }
Expand All @@ -43,7 +44,8 @@ public interface IValueMapper
/// type arguments for GetMapper, but typically contain additional information like
/// vector lengths.
/// </summary>
public interface IValueMapperDist : IValueMapper
[BestFriend]
internal interface IValueMapperDist : IValueMapper
{
ColumnType DistType { get; }

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public virtual void Save(ModelSaveContext ctx)
SaveCore(ctx);
}

protected virtual void SaveCore(ModelSaveContext ctx)
[BestFriend]
private protected virtual void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public interface ISampleableDistribution<TResult> : IDistribution<TResult>
/// <summary>
/// Predictors that can output themselves in a human-readable text format
/// </summary>
public interface ICanSaveInTextFormat
[BestFriend]
internal interface ICanSaveInTextFormat
{
void SaveAsText(TextWriter writer, RoleMappedSchema schema);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper
// ValueMapper or FloatPredictor is non-null (or both). With these guarantees,
// the score value type (_scoreType) can be determined.
protected readonly IPredictor Predictor;
protected readonly IValueMapper ValueMapper;
private protected readonly IValueMapper ValueMapper;
protected readonly ColumnType ScoreType;

bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLo
return new EnsembleDistributionPredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ct
return new EnsemblePredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadCont
ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoad
return new EnsembleMultiClassPredictor(env, ctx);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2891,7 +2891,8 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad
InputType = new VectorType(NumberType.Float, NumFeatures);
}

protected override void SaveCore(ModelSaveContext ctx)
[BestFriend]
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private FastTreeBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ private FastTreeRankingPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ private FastTreeRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ private FastTreeTweediePredictor(IHostEnvironment env, ModelLoadContext ctx)
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.TestValue)]

[assembly: WantsToBeBestFriends]
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext
{
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx
_quantileSampleCount = ctx.Reader.ReadInt32();
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ private OlsLinearRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx)
ProbCheckDecode(_pValues[i]);
}

protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.KMeansClustering;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Trainers.KMeans;
using System;

namespace Microsoft.ML
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,31 @@

using Float = System.Single;

using System;
using System.IO;
using Microsoft.ML.Runtime.Numeric;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.KMeansClustering;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Trainers.KMeans;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Numeric;
using System;
using System.IO;
using System.Collections.Generic;

[assembly: LoadableClass(typeof(KMeansPredictor), null, typeof(SignatureLoadModel),
"KMeans predictor", KMeansPredictor.LoaderSignature)]
[assembly: LoadableClass(typeof(KMeansModelParameters), null, typeof(SignatureLoadModel),
"KMeans predictor", KMeansModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers.KMeans
namespace Microsoft.ML.KMeansClustering
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Microsoft.ML.KMeansClustering [](start = 10, length = 29)

I wonder if we want to put all predictors in something like:

Microsoft.ML.ModelParameters.KMeansClustering

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the argument is that the namespace is "Trainers" and therefore only ITrainer implementors should live in that namespace, then I'd argue that we should not have "Trainers" in the namespace name, if you think it could cause confusion.

It would be very inconvenient to have predictors in a different namespace from their associated trainers (except in generalizable cases like linear predictors that can be produced by many things).


In reply to: 238843887 [](ancestors = 238843887)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Microsoft.ML.KMeans would be a better namespace, but let's not have one namespace every time we devise a new type that isn't a trainer, a predictor, or whatever. 😛

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So both the trainer and predictor would be in Microsoft.ML.KMeansClustering. That is what I gather from @Zruty0's description in #1699.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On #1318 there was an explicit ask to put all the trainers inside Microsoft.ML.Trainers or appropriate namespaces.

So if I read both of those issues correctly, the ModelParameters also go in the same namespace as the respective trainer; in this case, Microsoft.ML.Trainers.KMeans. @TomFinley, does that sound good?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I'll stick to the conventions in #1318 and revert the namespace to Microsoft.ML.Trainers.KMeans.

{
public sealed class KMeansPredictor :
public sealed class KMeansModelParameters :
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
PredictorBase<VBuffer<Float>>,
IValueMapper,
ICanSaveInTextFormat,
ICanSaveModel,
ISingleCanSaveOnnx
{
public const string LoaderSignature = "KMeansPredictor";
internal const string LoaderSignature = "KMeansPredictor";

/// <summary>
/// Version information to be saved in binary format
Expand All @@ -42,12 +42,16 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(KMeansPredictor).Assembly.FullName);
loaderAssemblyName: typeof(KMeansModelParameters).Assembly.FullName);
}

// REVIEW: Leaving this public for now until we figure out the correct way to remove it.
public override PredictionKind PredictionKind => PredictionKind.Clustering;
public ColumnType InputType { get; }
public ColumnType OutputType { get; }

private readonly ColumnType _inputType;
private readonly ColumnType _outputType;
ColumnType IValueMapper.InputType => _inputType;
ColumnType IValueMapper.OutputType => _outputType;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

Expand All @@ -66,7 +70,7 @@ private static VersionInfo GetVersionInfo()
/// a deep copy, if false then this constructor will take ownership of the passed in centroid vectors.
/// If false then the caller must take care to not use or modify the input vectors once this object
/// is constructed, and should probably remove all references.</param>
public KMeansPredictor(IHostEnvironment env, int k, VBuffer<float>[] centroids, bool copyIn)
public KMeansModelParameters(IHostEnvironment env, int k, VBuffer<float>[] centroids, bool copyIn)
: base(env, LoaderSignature)
{
Host.CheckParam(k > 0, nameof(k), "Need at least one cluster");
Expand All @@ -92,16 +96,16 @@ public KMeansPredictor(IHostEnvironment env, int k, VBuffer<float>[] centroids,

InitPredictor();

InputType = new VectorType(NumberType.Float, _dimensionality);
OutputType = new VectorType(NumberType.Float, _k);
_inputType = new VectorType(NumberType.Float, _dimensionality);
_outputType = new VectorType(NumberType.Float, _k);
}

/// <summary>
/// Initialize predictor from a binary file.
/// </summary>
/// <param name="ctx">The load context</param>
/// <param name="env">The host environment</param>
private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx)
private KMeansModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, ctx)
{
// *** Binary format ***
Expand Down Expand Up @@ -134,11 +138,11 @@ private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx)

InitPredictor();

InputType = new VectorType(NumberType.Float, _dimensionality);
OutputType = new VectorType(NumberType.Float, _k);
_inputType = new VectorType(NumberType.Float, _dimensionality);
_outputType = new VectorType(NumberType.Float, _k);
}

public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>()
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
najeeb-kazmi marked this conversation as resolved.
Show resolved Hide resolved
{
Host.Check(typeof(TIn) == typeof(VBuffer<Float>));
Host.Check(typeof(TOut) == typeof(VBuffer<Float>));
Expand Down Expand Up @@ -169,7 +173,7 @@ private void Map(in VBuffer<Float> src, Span<Float> distances)
}
}

public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it private?

{
writer.WriteLine("K: {0}", _k);
writer.WriteLine("Dimensionality: {0}", _dimensionality);
Expand Down Expand Up @@ -215,7 +219,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
/// Save the predictor in binary format.
/// </summary>
/// <param name="ctx">The context to save to</param>
protected override void SaveCore(ModelSaveContext ctx)
private protected override void SaveCore(ModelSaveContext ctx)
{
base.SaveCore(ctx);
ctx.SetVersionInfo(GetVersionInfo());
Expand Down Expand Up @@ -247,12 +251,12 @@ protected override void SaveCore(ModelSaveContext ctx)
/// <summary>
/// This method is called by reflection to instantiate a predictor.
/// </summary>
public static KMeansPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
private static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new KMeansPredictor(env, ctx);
return new KMeansModelParameters(env, ctx);
}

/// <summary>
Expand Down
Loading