diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs new file mode 100644 index 0000000000..c25e75b0e2 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs @@ -0,0 +1,47 @@ +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class KMeans_example + { + public static void KMeans() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Get a small dataset as an IEnumerable and convert it to an IDataView. + var data = SamplesUtils.DatasetUtils.GetInfertData(); + var trainData = ml.CreateStreamingDataView(data); + + // Preview of the data. + // + // Age Case Education Induced Parity PooledStratum RowNum ... + // 26 1 0-5yrs 1 6 3 1 ... + // 42 1 0-5yrs 1 1 1 2 ... + // 39 1 0-5yrs 2 6 4 3 ... + // 34 1 0-5yrs 2 4 2 4 ... + // 35 1 6-11yrs 1 3 32 5 ... + + // A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them. + string outputColumnName = "Features"; + var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" }) + .Append(ml.Clustering.Trainers.KMeans(outputColumnName, clustersCount: 2)); + + var model = pipeline.Fit(trainData); + + // Get cluster centroids and the number of clusters k from KMeansModelParameters. + VBuffer[] centroids = default; + int k; + + var modelParams = model.LastTransformer.Model; + modelParams.GetClusterCentroids(ref centroids, out k); + + var centroid = centroids[0].GetValues(); + Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray())); + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj index 6b3d9f957b..a8002fe9bf 100644 --- a/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj +++ b/docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj @@ -7,6 +7,7 @@ + diff --git a/src/Microsoft.ML.Core/Data/IValueMapper.cs b/src/Microsoft.ML.Core/Data/IValueMapper.cs index c6abfbc02d..d0b8d4b755 100644 --- a/src/Microsoft.ML.Core/Data/IValueMapper.cs +++ b/src/Microsoft.ML.Core/Data/IValueMapper.cs @@ -24,7 +24,8 @@ namespace Microsoft.ML.Runtime.Data /// type arguments for GetMapper, but typically contain additional information like /// vector lengths. /// - public interface IValueMapper + [BestFriend] + internal interface IValueMapper { ColumnType InputType { get; } ColumnType OutputType { get; } @@ -43,7 +44,8 @@ public interface IValueMapper /// type arguments for GetMapper, but typically contain additional information like /// vector lengths. /// - public interface IValueMapperDist : IValueMapper + [BestFriend] + internal interface IValueMapperDist : IValueMapper { ColumnType DistType { get; } diff --git a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs b/src/Microsoft.ML.Data/Dirty/PredictorBase.cs index 0d73db05fc..fe27705ff8 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorBase.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorBase.cs @@ -53,7 +53,8 @@ public virtual void Save(ModelSaveContext ctx) SaveCore(ctx); } - protected virtual void SaveCore(ModelSaveContext ctx) + [BestFriend] + private protected virtual void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index f866ca4817..af9eddb332 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -92,7 +92,8 @@ public interface ISampleableDistribution : IDistribution /// /// Predictors that can output themselves in a human-readable text format /// - public interface ICanSaveInTextFormat + [BestFriend] + internal interface ICanSaveInTextFormat { void SaveAsText(TextWriter writer, RoleMappedSchema schema); } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 26ae54bd0b..bff0f79442 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -159,7 +159,7 @@ public void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator ca saver?.SaveAsIni(writer, schema, Calibrator); } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanSaveInTextFormat; diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index a2fdf96af4..248e7936c8 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -41,7 +41,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper // ValueMapper or FloatPredictor is non-null (or both). With these guarantees, // the score value type (_scoreType) can be determined. protected readonly IPredictor Predictor; - protected readonly IValueMapper ValueMapper; + private protected readonly IValueMapper ValueMapper; protected readonly ColumnType ScoreType; bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs index 5a92512e6d..6346abfc80 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs @@ -109,7 +109,7 @@ public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLo return new EnsembleDistributionPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs index b69b76f40e..30b2dce057 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs @@ -99,7 +99,7 @@ public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ct return new EnsemblePredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs index 032312acde..72a81b12e3 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs @@ -86,7 +86,7 @@ protected EnsemblePredictorBase(IHostEnvironment env, string name, ModelLoadCont ctx.LoadModel, SignatureLoadModel>(Host, out Combiner, @"Combiner"); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); @@ -128,7 +128,7 @@ protected override void SaveCore(ModelSaveContext ctx) /// /// Output the INI model to a given writer /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { using (var ch = Host.Start("SaveAsText")) { diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs index 6f931f3e9c..a332d3ddb0 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs @@ -95,7 +95,7 @@ public static EnsembleMultiClassPredictor Create(IHostEnvironment env, ModelLoad return new EnsembleMultiClassPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 21a5737944..9cc6e1666c 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2829,8 +2829,12 @@ public abstract class FastTreePredictionWrapper : protected abstract uint VerCategoricalSplitSerialized { get; } - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; + protected internal readonly ColumnType InputType; + ColumnType IValueMapper.InputType => InputType; + + protected readonly ColumnType OutputType; + ColumnType IValueMapper.OutputType => OutputType; + bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; @@ -2852,6 +2856,7 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsem Contracts.Assert(NumFeatures > MaxSplitFeatIdx); InputType = new VectorType(NumberType.Float, NumFeatures); + OutputType = NumberType.Float; } protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver) @@ -2889,9 +2894,11 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad // tricks. InputType = new VectorType(NumberType.Float, NumFeatures); + OutputType = NumberType.Float; } - protected override void SaveCore(ModelSaveContext ctx) + [BestFriend] + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); @@ -2906,7 +2913,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(NumFeatures); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Float)); @@ -2964,7 +2971,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) /// /// Output the INI model to a given writer /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValueOrNull(schema); diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index e628c0dce9..da89d8394f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -79,7 +79,7 @@ private FastTreeBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index af4c90a20b..b3967ef5db 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -1125,7 +1125,7 @@ private FastTreeRankingPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 1ba21904c0..a101284957 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -474,7 +474,7 @@ private FastTreeRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 45668b56a9..65f6151ea6 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -477,7 +477,7 @@ private FastTreeTweediePredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index b38f64d381..d073c40b5b 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -631,6 +631,7 @@ public abstract class GamPredictorBase : PredictorBase, public readonly double Intercept; private readonly int _numFeatures; private readonly ColumnType _inputType; + private readonly ColumnType _outputType; // These would be the bins for a totally sparse input. private readonly int[] _binsAtAllZero; // The output value for all zeros @@ -640,9 +641,8 @@ public abstract class GamPredictorBase : PredictorBase, private readonly int _inputLength; private readonly Dictionary _inputFeatureToDatasetFeatureMap; - public ColumnType InputType => _inputType; - - public ColumnType OutputType => NumberType.Float; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; private protected GamPredictorBase(IHostEnvironment env, string name, int inputLength, Dataset trainSet, double meanEffect, double[][] binEffects, int[] featureMap) @@ -658,6 +658,7 @@ private protected GamPredictorBase(IHostEnvironment env, string name, _numFeatures = binEffects.Length; _inputType = new VectorType(NumberType.Float, _inputLength); + _outputType = NumberType.Float; _featureMap = featureMap; Intercept = meanEffect; @@ -762,6 +763,7 @@ protected GamPredictorBase(IHostEnvironment env, string name, ModelLoadContext c } _inputType = new VectorType(NumberType.Float, _inputLength); + _outputType = NumberType.Float; } public override void Save(ModelSaveContext ctx) @@ -975,7 +977,7 @@ public double[] GetFeatureWeights(int featureIndex) return featureWeights; } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValueOrNull(schema); @@ -1018,7 +1020,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) public void SaveSummary(TextWriter writer, RoleMappedSchema schema) { - SaveAsText(writer, schema); + ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } /// @@ -1097,7 +1099,7 @@ private sealed class Context /// These are the number of input features, as opposed to the number of features used within GAM /// which may be lower. /// - public int NumFeatures => _pred.InputType.VectorSize; + public int NumFeatures => _pred._inputType.VectorSize; public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluator eval) { @@ -1317,7 +1319,7 @@ private FeatureInfo(Context context, int index, int internalIndex, int[] catsMap public static FeatureInfo GetInfoForIndex(Context context, int index) { Contracts.AssertValue(context); - Contracts.Assert(0 <= index && index < context._pred.InputType.ValueCount); + Contracts.Assert(0 <= index && index < context._pred._inputType.ValueCount); lock (context._pred) { int internalIndex; diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs index a03d7bdab6..cd27563c10 100644 --- a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs @@ -6,5 +6,6 @@ using Microsoft.ML; [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 9b7f4f78b1..e70713f515 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -88,7 +88,7 @@ private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index a5fec96249..ab7e0f362b 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -76,7 +76,7 @@ private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx _quantileSampleCount = ctx.Reader.ReadInt32(); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 33eecb3b65..0442118442 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -296,7 +296,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac Double tss = 0; // total sum of squares using (var cursor = cursorFactory.Create()) { - var lrPredictor = new LinearRegressionPredictor(Host, in weights, bias); + IValueMapper lrPredictor = new LinearRegressionPredictor(Host, in weights, bias); var lrMap = lrPredictor.GetMapper, float>(); float yh = default; while (cursor.MoveNext()) @@ -682,7 +682,7 @@ private OlsLinearRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) ProbCheckDecode(_pValues[i]); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs similarity index 89% rename from src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs rename to src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index 38b5116da4..e5862249bb 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -4,31 +4,37 @@ using Float = System.Single; -using System; -using System.IO; -using Microsoft.ML.Runtime.Numeric; -using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.KMeans; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Model.Onnx; -using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Numeric; +using Microsoft.ML.Trainers.KMeans; +using System; +using System.IO; using System.Collections.Generic; -[assembly: LoadableClass(typeof(KMeansPredictor), null, typeof(SignatureLoadModel), - "KMeans predictor", KMeansPredictor.LoaderSignature)] +[assembly: LoadableClass(typeof(KMeansModelParameters), null, typeof(SignatureLoadModel), + "KMeans predictor", KMeansModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.KMeans { - public sealed class KMeansPredictor : + /// + /// + /// + /// + public sealed class KMeansModelParameters : PredictorBase>, IValueMapper, ICanSaveInTextFormat, ICanSaveModel, ISingleCanSaveOnnx { - public const string LoaderSignature = "KMeansPredictor"; + internal const string LoaderSignature = "KMeansPredictor"; /// /// Version information to be saved in binary format @@ -42,12 +48,16 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(KMeansPredictor).Assembly.FullName); + loaderAssemblyName: typeof(KMeansModelParameters).Assembly.FullName); } + // REVIEW: Leaving this public for now until we figure out the correct way to remove it. public override PredictionKind PredictionKind => PredictionKind.Clustering; - public ColumnType InputType { get; } - public ColumnType OutputType { get; } + + private readonly ColumnType _inputType; + private readonly ColumnType _outputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; @@ -66,7 +76,7 @@ private static VersionInfo GetVersionInfo() /// a deep copy, if false then this constructor will take ownership of the passed in centroid vectors. /// If false then the caller must take care to not use or modify the input vectors once this object /// is constructed, and should probably remove all references. - public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, bool copyIn) + public KMeansModelParameters(IHostEnvironment env, int k, VBuffer[] centroids, bool copyIn) : base(env, LoaderSignature) { Host.CheckParam(k > 0, nameof(k), "Need at least one cluster"); @@ -92,8 +102,8 @@ public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, InitPredictor(); - InputType = new VectorType(NumberType.Float, _dimensionality); - OutputType = new VectorType(NumberType.Float, _k); + _inputType = new VectorType(NumberType.Float, _dimensionality); + _outputType = new VectorType(NumberType.Float, _k); } /// @@ -101,7 +111,7 @@ public KMeansPredictor(IHostEnvironment env, int k, VBuffer[] centroids, /// /// The load context /// The host environment - private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx) + private KMeansModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, ctx) { // *** Binary format *** @@ -134,11 +144,11 @@ private KMeansPredictor(IHostEnvironment env, ModelLoadContext ctx) InitPredictor(); - InputType = new VectorType(NumberType.Float, _dimensionality); - OutputType = new VectorType(NumberType.Float, _k); + _inputType = new VectorType(NumberType.Float, _dimensionality); + _outputType = new VectorType(NumberType.Float, _k); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); @@ -169,7 +179,7 @@ private void Map(in VBuffer src, Span distances) } } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine("K: {0}", _k); writer.WriteLine("Dimensionality: {0}", _dimensionality); @@ -215,7 +225,7 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema) /// Save the predictor in binary format. /// /// The context to save to - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -247,12 +257,12 @@ protected override void SaveCore(ModelSaveContext ctx) /// /// This method is called by reflection to instantiate a predictor. /// - public static KMeansPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new KMeansPredictor(env, ctx); + return new KMeansModelParameters(env, ctx); } /// diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 41084bd5ab..a72d148be7 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -29,7 +29,7 @@ namespace Microsoft.ML.Trainers.KMeans { /// - public class KMeansPlusPlusTrainer : TrainerEstimatorBase, KMeansPredictor> + public class KMeansPlusPlusTrainer : TrainerEstimatorBase, KMeansModelParameters> { public const string LoadNameValue = "KMeansPlusPlus"; internal const string UserNameValue = "KMeans++ Clustering"; @@ -151,7 +151,7 @@ private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action MakeTransformer(KMeansPredictor model, Schema trainSchema) - => new ClusteringPredictionTransformer(Host, model, trainSchema, _featureColumn); + protected override ClusteringPredictionTransformer MakeTransformer(KMeansModelParameters model, Schema trainSchema) + => new ClusteringPredictionTransformer(Host, model, trainSchema, _featureColumn); } internal static class KMeansPlusPlusInit diff --git a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs index f5fe8a91d9..1fb94afc8e 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansStatic.cs @@ -4,8 +4,8 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.KMeans; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Trainers.KMeans; using System; namespace Microsoft.ML.StaticPipe @@ -33,7 +33,7 @@ public static (Vector score, Key predictedLabel) KMeans(this Cluste Vector features, Scalar weights = null, int clustersCount = KMeansPlusPlusTrainer.Defaults.K, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { Contracts.CheckValue(features, nameof(features)); Contracts.CheckValueOrNull(weights); diff --git a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs index 448395769a..47ff917bf8 100644 --- a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs +++ b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs @@ -4,11 +4,11 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Sweeper; +using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Trainers.KMeans; using Microsoft.ML.Trainers.PCA; -using Microsoft.ML.Runtime.Sweeper; -using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms.Categorical; using System; using System.Reflection; @@ -44,7 +44,7 @@ private static bool LoadStandardAssemblies() _ = typeof(TextLoader).Assembly; // ML.Data //_ = typeof(EnsemblePredictor).Assembly); // ML.Ensemble BUG https://github.com/dotnet/machinelearning/issues/1078 Ensemble isn't in a NuGet package _ = typeof(FastTreeBinaryPredictor).Assembly; // ML.FastTree - _ = typeof(KMeansPredictor).Assembly; // ML.KMeansClustering + _ = typeof(KMeansModelParameters).Assembly; // ML.KMeansClustering _ = typeof(Maml).Assembly; // ML.Maml _ = typeof(PcaPredictor).Assembly; // ML.PCA _ = typeof(SweepCommand).Assembly; // ML.Sweeper diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index ffec7f10ee..159e32ec56 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -65,7 +65,7 @@ private LightGbmBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index e7f76bf215..ba998b8545 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -61,7 +61,7 @@ private LightGbmRankingPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index f59d289e3a..062e994520 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -61,7 +61,7 @@ private LightGbmRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) { } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index f6ab01820f..6a48ce5d3f 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -462,7 +462,7 @@ private PcaPredictor(IHostEnvironment env, ModelLoadContext ctx) _inputType = new VectorType(NumberType.Float, _dimension); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -500,10 +500,10 @@ public static PcaPredictor Create(IHostEnvironment env, ModelLoadContext ctx) public void SaveSummary(TextWriter writer, RoleMappedSchema schema) { - SaveAsText(writer, schema); + ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine("Dimension: {0}", _dimension); writer.WriteLine("Rank: {0}", _rank); @@ -550,17 +550,17 @@ public IDataView GetSummaryDataView(RoleMappedSchema schema) return bldr.GetDataView(); } - public ColumnType InputType + ColumnType IValueMapper.InputType { get { return _inputType; } } - public ColumnType OutputType + ColumnType IValueMapper.OutputType { get { return NumberType.Float; } } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(float)); diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index 71219ee19e..f9c39fa88c 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -163,7 +163,7 @@ public void Save(ModelSaveContext ctx) /// /// Save the trained matrix factorization model (two factor matrices) in text format /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine("# Imputed matrix is P * Q'"); writer.WriteLine("# P in R^({0} x {1}), rows correpond to Y item", _numberOfRows, _approximationRank); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index 91e57b13a9..7c818102bc 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -120,7 +120,7 @@ public static FieldAwareFactorizationMachinePredictor Create(IHostEnvironment en return new FieldAwareFactorizationMachinePredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 5f3c0c72cc..be074d5866 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -97,9 +97,7 @@ IEnumerator IEnumerable.GetEnumerator() /// The predictor's bias term. public Float Bias { get; protected set; } - public ColumnType InputType { get; } - - public ColumnType OutputType => NumberType.Float; + private readonly ColumnType _inputType; bool ICanSavePfa.CanSavePfa => true; @@ -121,7 +119,7 @@ internal LinearPredictor(IHostEnvironment env, string name, in VBuffer we Weight = weights; Bias = bias; - InputType = new VectorType(NumberType.Float, Weight.Length); + _inputType = new VectorType(NumberType.Float, Weight.Length); if (Weight.IsDense) _weightsDense = Weight; @@ -176,7 +174,7 @@ protected LinearPredictor(IHostEnvironment env, string name, ModelLoadContext ct else Weight = new VBuffer(len, Utils.Size(weights), weights, indices); - InputType = new VectorType(NumberType.Float, Weight.Length); + _inputType = new VectorType(NumberType.Float, Weight.Length); WarnOnOldNormalizer(ctx, GetType(), Host); if (Weight.IsDense) @@ -185,7 +183,8 @@ protected LinearPredictor(IHostEnvironment env, string name, ModelLoadContext ct _weightsDenseLock = new object(); } - protected override void SaveCore(ModelSaveContext ctx) + [BestFriend] + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); @@ -283,7 +282,17 @@ private void EnsureWeightsDense() } } - public ValueMapper GetMapper() + ColumnType IValueMapper.InputType + { + get { return _inputType; } + } + + ColumnType IValueMapper.OutputType + { + get { return NumberType.Float; } + } + + ValueMapper IValueMapper.GetMapper() { Contracts.Check(typeof(TIn) == typeof(VBuffer)); Contracts.Check(typeof(TOut) == typeof(Float)); @@ -326,7 +335,7 @@ protected void CombineParameters(IList> models, out VBuff bias /= models.Count; } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -459,7 +468,7 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC return new SchemaBindableCalibratedPredictor(env, predictor, calibrator); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { // *** Binary format *** // (Base class) @@ -613,7 +622,7 @@ public static LinearRegressionPredictor Create(IHostEnvironment env, ModelLoadCo return new LinearRegressionPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -689,7 +698,7 @@ public static PoissonRegressionPredictor Create(IHostEnvironment env, ModelLoadC return new PoissonRegressionPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index af41587e33..d4bd1f73ae 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -394,8 +394,11 @@ private static VersionInfo GetVersionInfo() private volatile VBuffer[] _weightsDense; public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public ColumnType InputType { get; } - public ColumnType OutputType { get; } + internal readonly ColumnType InputType; + internal readonly ColumnType OutputType; + ColumnType IValueMapper.InputType => InputType; + ColumnType IValueMapper.OutputType => OutputType; + bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; @@ -567,7 +570,7 @@ public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, return new MulticlassLogisticRegressionPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -710,7 +713,7 @@ private static int NonZeroCount(in VBuffer vector) return count; } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TSrc) == typeof(VBuffer), "Invalid source type in GetMapper"); Host.Check(typeof(TDst) == typeof(VBuffer), "Invalid destination type in GetMapper"); @@ -781,7 +784,7 @@ private void Calibrate(Span dst) /// /// Output the text model to a given writer /// - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine(nameof(MulticlassLogisticRegression) + " bias and non-zero weights"); @@ -857,7 +860,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) public void SaveSummary(TextWriter writer, RoleMappedSchema schema) { - SaveAsText(writer, schema); + ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 43ddb8a355..46e964a487 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -209,9 +209,9 @@ private static VersionInfo GetVersionInfo() public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public ColumnType InputType => _inputType; + ColumnType IValueMapper.InputType => _inputType; - public ColumnType OutputType => _outputType; + ColumnType IValueMapper.OutputType => _outputType; /// /// Copies the label histogram into a buffer. @@ -306,7 +306,7 @@ public static MultiClassNaiveBayesPredictor Create(IHostEnvironment env, ModelLo return new MultiClassNaiveBayesPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -352,7 +352,7 @@ private static double[] CalculateAbsentFeatureLogProbabilities(int[] labelHistog return absentFeaturesLogProb; } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index a8894f9149..7f41ba7320 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -226,9 +226,8 @@ private static VersionInfo GetVersionInfo() private readonly ImplBase _impl; public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public ColumnType InputType => _impl.InputType; - public ColumnType OutputType { get; } - public ColumnType DistType => OutputType; + private readonly ColumnType _outputType; + public ColumnType DistType => _outputType; bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] @@ -279,7 +278,7 @@ private OvaPredictor(IHostEnvironment env, ImplBase impl) Host.Assert(Utils.Size(impl.Predictors) > 0); _impl = impl; - OutputType = new VectorType(NumberType.Float, _impl.Predictors.Length); + _outputType = new VectorType(NumberType.Float, _impl.Predictors.Length); } private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -305,7 +304,7 @@ private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx) _impl = new ImplRaw(predictors); } - OutputType = new VectorType(NumberType.Float, _impl.Predictors.Length); + _outputType = new VectorType(NumberType.Float, _impl.Predictors.Length); } public static OvaPredictor Create(IHostEnvironment env, ModelLoadContext ctx) @@ -323,7 +322,7 @@ private static void LoadPredictors(IHostEnvironment env, TPredictor[ ctx.LoadModel(env, out predictors[i], string.Format(SubPredictorFmt, i)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -348,7 +347,16 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) return _impl.SaveAsPfa(ctx, input); } - public ValueMapper GetMapper() + ColumnType IValueMapper.InputType + { + get { return _impl.InputType; } + } + + ColumnType IValueMapper.OutputType + { + get { return _outputType; } + } + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); @@ -376,7 +384,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) } } - public void SaveAsText(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 9f5ab8ea3c..e74d707d68 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -244,8 +244,10 @@ private static VersionInfo GetVersionInfo() private readonly IValueMapperDist[] _mappers; public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; - public ColumnType InputType { get; } - public ColumnType OutputType { get; } + private readonly ColumnType _inputType; + private readonly ColumnType _outputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => _outputType; internal PkpdPredictor(IHostEnvironment env, TDistPredictor[][] predictors) : base(env, RegistrationName) @@ -267,8 +269,8 @@ internal PkpdPredictor(IHostEnvironment env, TDistPredictor[][] predictors) : } Host.Assert(index == _predictors.Length); - InputType = InitializeMappers(out _mappers); - OutputType = new VectorType(NumberType.Float, _numClasses); + _inputType = InitializeMappers(out _mappers); + _outputType = new VectorType(NumberType.Float, _numClasses); } private PkpdPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -295,8 +297,8 @@ private PkpdPredictor(IHostEnvironment env, ModelLoadContext ctx) Host.Assert(index == GetIndex(i, i)); ctx.LoadModel(Host, out _predictors[index++], string.Format(SubPredictorFmt, i)); } - InputType = InitializeMappers(out _mappers); - OutputType = new VectorType(NumberType.Float, _numClasses); + _inputType = InitializeMappers(out _mappers); + _outputType = new VectorType(NumberType.Float, _numClasses); } private ColumnType InitializeMappers(out IValueMapperDist[] mappers) @@ -337,7 +339,7 @@ public static PkpdPredictor Create(IHostEnvironment env, ModelLoadContext ctx) return new PkpdPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -442,7 +444,7 @@ private int GetIndex(int i, int j) return i * (i + 1) / 2 + j; } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); @@ -455,8 +457,8 @@ public ValueMapper GetMapper() ValueMapper, VBuffer> del = (in VBuffer src, ref VBuffer dst) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index ec39fc1e1c..eaa45068f7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -172,7 +172,7 @@ public static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx) /// Save the predictor in the binary format. /// /// - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -395,7 +395,7 @@ public static PriorPredictor Create(IHostEnvironment env, ModelLoadContext ctx) return new PriorPredictor(env, ctx); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index b59c2118b5..e1a5f7be3d 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -5,21 +5,21 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FactorizationMachine; -using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; -using Microsoft.ML.Trainers.KMeans; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.StaticPipe; using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.KMeans; +using Microsoft.ML.Trainers.Recommender; using Microsoft.ML.Transforms.Categorical; using System; using System.Linq; using Xunit; using Xunit.Abstractions; -using Microsoft.ML.Trainers.Recommender; namespace Microsoft.ML.StaticPipelineTesting { @@ -676,7 +676,7 @@ public void KMeans() var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); - KMeansPredictor pred = null; + KMeansModelParameters pred = null; var est = reader.MakeNewEstimator() .Append(r => (label: r.label.ToKey(), r.features)) diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs index 13a78fdf3f..139baafeaf 100644 --- a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs +++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs @@ -23,7 +23,7 @@ public static TEnvironment AddStandardComponents(this TEnvironment env.ComponentCatalog.RegisterAssembly(typeof(OneHotEncodingTransformer).Assembly); // ML.Transforms env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryPredictor).Assembly); // ML.FastTree env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble - env.ComponentCatalog.RegisterAssembly(typeof(KMeansPredictor).Assembly); // ML.KMeansClustering + env.ComponentCatalog.RegisterAssembly(typeof(KMeansModelParameters).Assembly); // ML.KMeansClustering env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA env.ComponentCatalog.RegisterAssembly(typeof(Experiment).Assembly); // ML.Legacy return env;