diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs
index 30670cd96e..d956532921 100644
--- a/src/Microsoft.ML.PCA/PcaTransform.cs
+++ b/src/Microsoft.ML.PCA/PcaTransform.cs
@@ -1,12 +1,12 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
using System;
+using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
@@ -15,48 +15,48 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Numeric;
+using Microsoft.ML.StaticPipe;
+using Microsoft.ML.StaticPipe.Runtime;
+using Microsoft.ML.Transforms;
-[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), typeof(PcaTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(PcaTransform.Summary, typeof(IDataTransform), typeof(PcaTransform), typeof(PcaTransform.Arguments), typeof(SignatureDataTransform),
PcaTransform.UserName, PcaTransform.LoaderSignature, PcaTransform.ShortName)]
-[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(PcaTransform.Summary, typeof(IDataTransform), typeof(PcaTransform), null, typeof(SignatureLoadDataTransform),
+ PcaTransform.UserName, PcaTransform.LoaderSignature)]
+
+[assembly: LoadableClass(PcaTransform.Summary, typeof(PcaTransform), null, typeof(SignatureLoadModel),
+ PcaTransform.UserName, PcaTransform.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(PcaTransform), null, typeof(SignatureLoadRowMapper),
PcaTransform.UserName, PcaTransform.LoaderSignature)]
[assembly: LoadableClass(typeof(void), typeof(PcaTransform), null, typeof(SignatureEntryPointModule), PcaTransform.LoaderSignature)]
-namespace Microsoft.ML.Runtime.Data
+namespace Microsoft.ML.Transforms
{
///
- public sealed class PcaTransform : OneToOneTransformBase
+ public sealed class PcaTransform : OneToOneTransformerBase
{
- internal static class Defaults
- {
- public const string WeightColumn = null;
- public const int Rank = 20;
- public const int Oversampling = 20;
- public const bool Center = true;
- public const int Seed = 0;
- }
-
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
[Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight", Purpose = SpecialPurpose.ColumnName)]
- public string WeightColumn = Defaults.WeightColumn;
+ public string WeightColumn = PcaEstimator.Defaults.WeightColumn;
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")]
- public int Rank = Defaults.Rank;
+ public int Rank = PcaEstimator.Defaults.Rank;
[Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")]
- public int Oversampling = Defaults.Oversampling;
+ public int Oversampling = PcaEstimator.Defaults.Oversampling;
[Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean")]
- public bool Center = Defaults.Center;
+ public bool Center = PcaEstimator.Defaults.Center;
[Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation")]
- public int Seed = Defaults.Seed;
+ public int Seed = PcaEstimator.Defaults.Seed;
}
public class Column : OneToOneColumn
@@ -98,22 +98,64 @@ public bool TryUnparse(StringBuilder sb)
}
}
+ public sealed class ColumnInfo
+ {
+ public readonly string Input;
+ public readonly string Output;
+ public readonly string WeightColumn;
+ public readonly int Rank;
+ public readonly int Oversampling;
+ public readonly bool Center;
+ public readonly int? Seed;
+
+ ///
+ /// Describes how the transformer handles one column pair.
+ ///
+ /// The column to apply PCA to.
+ /// The output column that contains PCA values.
+ /// The name of the weight column.
+ /// The number of components in the PCA.
+ /// Oversampling parameter for randomized PCA training.
+ /// If enabled, data is centered to be zero mean.
+ /// The seed for random number generation.
+ public ColumnInfo(string input,
+ string output,
+ string weightColumn = PcaEstimator.Defaults.WeightColumn,
+ int rank = PcaEstimator.Defaults.Rank,
+ int overSampling = PcaEstimator.Defaults.Oversampling,
+ bool center = PcaEstimator.Defaults.Center,
+ int? seed = null)
+ {
+ Input = input;
+ Output = output;
+ WeightColumn = weightColumn;
+ Rank = rank;
+ Oversampling = overSampling;
+ Center = center;
+ Seed = seed;
+ Contracts.CheckParam(Oversampling >= 0, nameof(Oversampling), "Oversampling must be non-negative.");
+ Contracts.CheckParam(Rank > 0, nameof(Rank), "Rank must be positive.");
+ }
+ }
+
private sealed class TransformInfo
{
public readonly int Dimension;
public readonly int Rank;
- public Float[][] Eigenvectors;
- public Float[] MeanProjected;
+ public float[][] Eigenvectors;
+ public float[] MeanProjected;
- public TransformInfo(Column item, Arguments args, int d)
+ public ColumnType OutputType => new VectorType(NumberType.Float, Rank);
+
+ public TransformInfo(int rank, int dim)
{
- Dimension = d;
- Rank = item.Rank ?? args.Rank;
- Contracts.CheckUserArg(0 < Rank && Rank <= Dimension, nameof(item.Rank), "Rank must be positive, and at most the dimension of untransformed data");
+ Dimension = dim;
+ Rank = rank;
+ Contracts.CheckParam(0 < Rank && Rank <= Dimension, nameof(Rank), "Rank must be positive, and at most the dimension of untransformed data");
}
- public TransformInfo(ModelLoadContext ctx, int colValueCount)
+ public TransformInfo(ModelLoadContext ctx)
{
Contracts.AssertValue(ctx);
@@ -121,17 +163,15 @@ public TransformInfo(ModelLoadContext ctx, int colValueCount)
// int: Dimension
// int: Rank
// for i=0,..,Rank-1:
- // Float[]: the i'th eigenvector
+ // float[]: the i'th eigenvector
// int: the size of MeanProjected (0 if it is null)
- // Float[]: MeanProjected
+ // float[]: MeanProjected
Dimension = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(Dimension == colValueCount);
-
Rank = ctx.Reader.ReadInt32();
Contracts.CheckDecode(0 < Rank && Rank <= Dimension);
- Eigenvectors = new Float[Rank][];
+ Eigenvectors = new float[Rank][];
for (int i = 0; i < Rank; i++)
{
Eigenvectors[i] = ctx.Reader.ReadFloatArray(Dimension);
@@ -150,9 +190,9 @@ public void Save(ModelSaveContext ctx)
// int: Dimension
// int: Rank
// for i=0,..,Rank-1:
- // Float[]: the i'th eigenvector
+ // float[]: the i'th eigenvector
// int: the size of MeanProjected (0 if it is null)
- // Float[]: MeanProjected
+ // float[]: MeanProjected
Contracts.Assert(0 < Rank && Rank <= Dimension);
ctx.Writer.Write(Dimension);
@@ -166,7 +206,7 @@ public void Save(ModelSaveContext ctx)
ctx.Writer.WriteFloatArray(MeanProjected);
}
- internal void ProjectMean(Float[] mean)
+ public void ProjectMean(float[] mean)
{
Contracts.AssertValue(Eigenvectors);
if (mean == null)
@@ -175,7 +215,7 @@ internal void ProjectMean(Float[] mean)
return;
}
- MeanProjected = new Float[Rank];
+ MeanProjected = new float[Rank];
for (var i = 0; i < Rank; ++i)
MeanProjected[i] = VectorUtils.DotProduct(Eigenvectors[i], mean);
}
@@ -190,62 +230,41 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "PCA FUNC",
- verWrittenCur: 0x00010001, // Initial
- verReadableCur: 0x00010001,
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Got rid of writing float size in model context
+ verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(PcaTransform).Assembly.FullName);
}
- // These are parallel to Infos.
- private readonly ColumnType[] _types;
+ private readonly int _numColumns;
+ private readonly Mapper.ColumnSchemaInfo[] _schemaInfos;
private readonly TransformInfo[] _transformInfos;
- private readonly int[] _oversampling;
- private readonly bool[] _center;
- private readonly int[] _weightColumnIndex;
-
private const string RegistrationName = "Pca";
- ///
- /// Public constructor corresponding to SignatureDataTransform.
- ///
- public PcaTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column,
- input, TestIsFloatItem)
+ internal PcaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PcaTransform)), GetColumnPairs(columns))
{
- Host.AssertNonEmpty(Infos);
- Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Host.AssertNonEmpty(ColumnPairs);
+ _numColumns = columns.Length;
+ _transformInfos = new TransformInfo[_numColumns];
+ _schemaInfos = new Mapper.ColumnSchemaInfo[_numColumns];
- _transformInfos = new TransformInfo[args.Column.Length];
- _oversampling = new int[args.Column.Length];
- _center = new bool[args.Column.Length];
- _weightColumnIndex = new int[args.Column.Length];
- for (int i = 0; i < _transformInfos.Length; i++)
+ for (int i = 0; i < _numColumns; i++)
{
- Host.Check(Infos[i].TypeSrc.VectorSize > 1, "Pca transform can only be applied to columns with known dimensionality greater than 1");
- _transformInfos[i] = new TransformInfo(args.Column[i], args, Infos[i].TypeSrc.ValueCount);
- _center[i] = args.Column[i].Center ?? args.Center;
- _oversampling[i] = args.Column[i].Oversampling ?? args.Oversampling;
- Host.CheckUserArg(_oversampling[i] >= 0, nameof(args.Oversampling), "Oversampling must be non-negative");
- _weightColumnIndex[i] = -1;
- var weightColumn = args.Column[i].WeightColumn ?? args.WeightColumn;
- if (weightColumn != null)
- {
- if (!Source.Schema.TryGetColumnIndex(weightColumn, out _weightColumnIndex[i]))
- throw Host.Except("weight column '{0}' does not exist", weightColumn);
- var type = Source.Schema.GetColumnType(_weightColumnIndex[i]);
- Host.CheckUserArg(type == NumberType.Float, nameof(args.WeightColumn));
- }
+ var colInfo = columns[i];
+ var sInfo = _schemaInfos[i] = new Mapper.ColumnSchemaInfo(ColumnPairs[i], input.Schema, colInfo.WeightColumn);
+ ValidatePcaInput(Host, colInfo.Input, sInfo.InputType);
+ _transformInfos[i] = new TransformInfo(colInfo.Rank, sInfo.InputType.ValueCount);
}
- Train(args, _transformInfos, input);
-
- _types = InitColumnTypes();
+ Train(columns, _transformInfos, input);
}
- private PcaTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, TestIsFloatItem)
+ private PcaTransform(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
{
Host.AssertValue(ctx);
@@ -253,27 +272,53 @@ private PcaTransform(IHost host, ModelLoadContext ctx, IDataView input)
//
//
// transformInfos
- Host.AssertNonEmpty(Infos);
- _transformInfos = new TransformInfo[Infos.Length];
- for (int i = 0; i < Infos.Length; i++)
- _transformInfos[i] = new TransformInfo(ctx, Infos[i].TypeSrc.ValueCount);
- _types = InitColumnTypes();
+ Host.AssertNonEmpty(ColumnPairs);
+ _numColumns = ColumnPairs.Length;
+ _transformInfos = new TransformInfo[_numColumns];
+ for (int i = 0; i < _numColumns; i++)
+ _transformInfos[i] = new TransformInfo(ctx);
}
- public static PcaTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
+ // Factory method for SignatureDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+ env.CheckValue(args.Column, nameof(args.Column));
+ var cols = args.Column.Select(item => new ColumnInfo(
+ item.Source,
+ item.Name,
+ item.WeightColumn,
+ item.Rank ?? args.Rank,
+ item.Oversampling ?? args.Oversampling,
+ item.Center ?? args.Center,
+ item.Seed ?? args.Seed)).ToArray();
+ return new PcaTransform(env, input, cols).MakeDataTransform(input);
+ }
- // *** Binary format ***
- // int: sizeof(Float)
- //
- int cbFloat = ctx.Reader.ReadInt32();
- h.CheckDecode(cbFloat == sizeof(Float));
- return h.Apply("Loading Model", ch => new PcaTransform(h, ctx, input));
+ // Factory method for SignatureLoadModel.
+ private static PcaTransform Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(nameof(PcaTransform));
+
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ if (ctx.Header.ModelVerWritten == 0x00010001)
+ {
+ int cbFloat = ctx.Reader.ReadInt32();
+ env.CheckDecode(cbFloat == sizeof(float));
+ }
+ return new PcaTransform(host, ctx);
}
public override void Save(ModelSaveContext ctx)
@@ -283,54 +328,56 @@ public override void Save(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
- // int: sizeof(Float)
//
// transformInfos
- ctx.Writer.Write(sizeof(Float));
- SaveBase(ctx);
+ SaveColumns(ctx);
for (int i = 0; i < _transformInfos.Length; i++)
_transformInfos[i].Save(ctx);
}
-
- private void Train(Arguments args, TransformInfo[] transformInfos, IDataView trainingData)
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
{
- var y = new Float[transformInfos.Length][][];
- var omega = new Float[transformInfos.Length][][];
- var mean = new Float[transformInfos.Length][];
+ Contracts.CheckValue(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
+ }
- var oversampledRank = new int[transformInfos.Length];
+ private void Train(ColumnInfo[] columns, TransformInfo[] transformInfos, IDataView trainingData)
+ {
+ var y = new float[_numColumns][][];
+ var omega = new float[_numColumns][][];
+ var mean = new float[_numColumns][];
+ var oversampledRank = new int[_numColumns];
var rnd = Host.Rand;
Double totalMemoryUsageEstimate = 0;
- for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++)
+ for (int iinfo = 0; iinfo < _numColumns; iinfo++)
{
- oversampledRank[iinfo] = Math.Min(transformInfos[iinfo].Rank + _oversampling[iinfo], transformInfos[iinfo].Dimension);
+ oversampledRank[iinfo] = Math.Min(transformInfos[iinfo].Rank + columns[iinfo].Oversampling, transformInfos[iinfo].Dimension);
//exact: (size of the 2 big matrices + other minor allocations) / (2^30)
- Double colMemoryUsageEstimate = 2.0 * transformInfos[iinfo].Dimension * oversampledRank[iinfo] * sizeof(Float) / 1e9;
+ Double colMemoryUsageEstimate = 2.0 * transformInfos[iinfo].Dimension * oversampledRank[iinfo] * sizeof(float) / 1e9;
totalMemoryUsageEstimate += colMemoryUsageEstimate;
if (colMemoryUsageEstimate > 2)
{
using (var ch = Host.Start("Memory usage"))
{
ch.Info("Estimate memory usage for transforming column {1}: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.",
- colMemoryUsageEstimate, Infos[iinfo].Name);
+ colMemoryUsageEstimate, ColumnPairs[iinfo].input);
}
}
- y[iinfo] = new Float[oversampledRank[iinfo]][];
- omega[iinfo] = new Float[oversampledRank[iinfo]][];
+ y[iinfo] = new float[oversampledRank[iinfo]][];
+ omega[iinfo] = new float[oversampledRank[iinfo]][];
for (int i = 0; i < oversampledRank[iinfo]; i++)
{
- y[iinfo][i] = new Float[transformInfos[iinfo].Dimension];
- omega[iinfo][i] = new Float[transformInfos[iinfo].Dimension];
+ y[iinfo][i] = new float[transformInfos[iinfo].Dimension];
+ omega[iinfo][i] = new float[transformInfos[iinfo].Dimension];
for (int j = 0; j < transformInfos[iinfo].Dimension; j++)
{
- omega[iinfo][i][j] = (Float)Stats.SampleFromGaussian(rnd);
+ omega[iinfo][i][j] = (float)Stats.SampleFromGaussian(rnd);
}
}
- if (_center[iinfo])
- mean[iinfo] = new Float[transformInfos[iinfo].Dimension];
+ if (columns[iinfo].Center)
+ mean[iinfo] = new float[transformInfos[iinfo].Dimension];
}
if (totalMemoryUsageEstimate > 2)
{
@@ -365,15 +412,15 @@ private void Train(Arguments args, TransformInfo[] transformInfos, IDataView tra
for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++)
{
//Compute B2 = B' * B
- var b2 = new Float[oversampledRank[iinfo] * oversampledRank[iinfo]];
+ var b2 = new float[oversampledRank[iinfo] * oversampledRank[iinfo]];
for (var i = 0; i < oversampledRank[iinfo]; ++i)
{
for (var j = i; j < oversampledRank[iinfo]; ++j)
b2[i * oversampledRank[iinfo] + j] = b2[j * oversampledRank[iinfo] + i] = VectorUtils.DotProduct(b[iinfo][i], b[iinfo][j]);
}
- Float[] smallEigenvalues; // eigenvectors and eigenvalues of the small matrix B2.
- Float[] smallEigenvectors;
+ float[] smallEigenvalues; // eigenvectors and eigenvalues of the small matrix B2.
+ float[] smallEigenvectors;
EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors);
transformInfos[iinfo].Eigenvectors = PostProcess(b[iinfo], smallEigenvalues, smallEigenvectors, transformInfos[iinfo].Dimension, oversampledRank[iinfo]);
@@ -384,9 +431,9 @@ private void Train(Arguments args, TransformInfo[] transformInfos, IDataView tra
//Project the covariance matrix A on to Omega: Y <- A * Omega
//A = X' * X / n, where X = data - mean
//Note that the covariance matrix is not computed explicitly
- private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, Float[][][] y, TransformInfo[] transformInfos)
+ private void Project(IDataView trainingData, float[][] mean, float[][][] omega, float[][][] y, TransformInfo[] transformInfos)
{
- Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == Infos.Length);
+ Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == _numColumns);
for (int i = 0; i < omega.Length; i++)
Contracts.Assert(omega[i].Length == y[i].Length);
@@ -399,37 +446,35 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega,
bool[] center = Enumerable.Range(0, mean.Length).Select(i => mean[i] != null).ToArray();
- Double[] totalColWeight = new Double[Infos.Length];
+ Double[] totalColWeight = new Double[_numColumns];
- bool[] activeColumns = new bool[Source.Schema.ColumnCount];
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
+ foreach (var sInfo in _schemaInfos)
{
- activeColumns[Infos[iinfo].Source] = true;
- if (_weightColumnIndex[iinfo] >= 0)
- activeColumns[_weightColumnIndex[iinfo]] = true;
+ activeColumns[sInfo.InputIndex] = true;
+ if (sInfo.WeightColumnIndex >= 0)
+ activeColumns[sInfo.WeightColumnIndex] = true;
}
+
using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
{
- var weightGetters = new ValueGetter[Infos.Length];
- var columnGetters = new ValueGetter>[Infos.Length];
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ var weightGetters = new ValueGetter[_numColumns];
+ var columnGetters = new ValueGetter>[_numColumns];
+ for (int iinfo = 0; iinfo < _numColumns; iinfo++)
{
- if (_weightColumnIndex[iinfo] >= 0)
- weightGetters[iinfo] = cursor.GetGetter(_weightColumnIndex[iinfo]);
- columnGetters[iinfo] = cursor.GetGetter>(Infos[iinfo].Source);
+ var sInfo = _schemaInfos[iinfo];
+ if (sInfo.WeightColumnIndex >= 0)
+ weightGetters[iinfo] = cursor.GetGetter(sInfo.WeightColumnIndex);
+ columnGetters[iinfo] = cursor.GetGetter>(sInfo.InputIndex);
}
- var features = default(VBuffer);
+ var features = default(VBuffer);
while (cursor.MoveNext())
{
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ for (int iinfo = 0; iinfo < _numColumns; iinfo++)
{
- Contracts.Check(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsNumber,
- "PCA transform can only be performed on numeric columns of dimension > 1");
-
- Float weight = 1;
- if (weightGetters[iinfo] != null)
- weightGetters[iinfo](ref weight);
+ float weight = 1;
+ weightGetters[iinfo]?.Invoke(ref weight);
columnGetters[iinfo](ref features);
if (FloatUtils.IsFinite(weight) && weight >= 0 && (features.Count == 0 || FloatUtils.IsFinite(features.Values, features.Count)))
@@ -445,15 +490,15 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega,
}
}
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ for (int iinfo = 0; iinfo < _numColumns; iinfo++)
{
if (totalColWeight[iinfo] <= 0)
- throw Host.Except("Empty data in column '{0}'", Source.Schema.GetColumnName(Infos[iinfo].Source));
+ throw Host.Except("Empty data in column '{0}'", ColumnPairs[iinfo].input);
}
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ for (int iinfo = 0; iinfo < _numColumns; iinfo++)
{
- var invn = (Float)(1 / totalColWeight[iinfo]);
+ var invn = (float)(1 / totalColWeight[iinfo]);
for (var i = 0; i < omega[iinfo].Length; ++i)
VectorUtils.ScaleBy(y[iinfo][i], invn);
@@ -470,13 +515,13 @@ private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega,
//return Y * eigenvectors / eigenvalues
// REVIEW: improve
- private Float[][] PostProcess(Float[][] y, Float[] sigma, Float[] z, int d, int k)
+ private float[][] PostProcess(float[][] y, float[] sigma, float[] z, int d, int k)
{
- var pinv = new Float[k];
- var tmp = new Float[k];
+ var pinv = new float[k];
+ var tmp = new float[k];
for (int i = 0; i < k; i++)
- pinv[i] = (Float)(1.0) / ((Float)(1e-6) + sigma[i]);
+ pinv[i] = (float)(1.0) / ((float)(1e-6) + sigma[i]);
for (int i = 0; i < d; i++)
{
@@ -493,56 +538,109 @@ private Float[][] PostProcess(Float[][] y, Float[] sigma, Float[] z, int d, int
return y;
}
- private ColumnType[] InitColumnTypes()
+ protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, Schema.Create(schema));
+
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Host.Assert(Infos.Length == _transformInfos.Length);
- var types = new ColumnType[Infos.Length];
- for (int i = 0; i < _transformInfos.Length; i++)
- types[i] = new VectorType(NumberType.Float, _transformInfos[i].Rank);
- Metadata.Seal();
- return types;
+ ValidatePcaInput(Host, inputSchema.GetColumnName(srcCol), inputSchema.GetColumnType(srcCol));
}
- protected override ColumnType GetColumnTypeCore(int iinfo)
+ internal static void ValidatePcaInput(IExceptionContext ectx, string name, ColumnType type)
{
- Host.Check(0 <= iinfo & iinfo < Utils.Size(_types));
- return _types[iinfo];
+ string inputSchema; // just used for the excpections
+
+ if (!(type.IsKnownSizeVector && type.VectorSize > 1 && type.ItemType.Equals(NumberType.R4)))
+ throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "vector of floats with fixed size greater than 1", type.ToString());
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ private sealed class Mapper : MapperBase
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
- disposer = null;
-
- var getSrc = GetSrcGetter>(input, iinfo);
- var src = default(VBuffer);
- var trInfo = _transformInfos[iinfo];
- ValueGetter> del =
- (ref VBuffer dst) =>
+ public sealed class ColumnSchemaInfo
+ {
+ public ColumnType InputType { get; }
+ public int InputIndex { get; }
+ public int WeightColumnIndex { get; }
+
+ public ColumnSchemaInfo((string input, string output) columnPair, Schema schema, string weightColumn = null)
{
- getSrc(ref src);
- TransformFeatures(Host, ref src, ref dst, trInfo);
- };
- return del;
- }
+ schema.TryGetColumnIndex(columnPair.input, out int inputIndex);
+ InputIndex = inputIndex;
+ InputType = schema[columnPair.input].Type;
- private static void TransformFeatures(IExceptionContext ectx, ref VBuffer src, ref VBuffer dst, TransformInfo transformInfo)
- {
- ectx.Check(src.Length == transformInfo.Dimension);
+ var weightIndex = -1;
+ if (weightColumn != null)
+ {
+ if (!schema.TryGetColumnIndex(weightColumn, out weightIndex))
+ throw Contracts.Except("Weight column '{0}' does not exist.", weightColumn);
+ Contracts.CheckParam(schema[weightIndex].Type == NumberType.Float, nameof(weightColumn));
+ }
+ WeightColumnIndex = weightIndex;
+ }
+ }
- var values = dst.Values;
- if (Utils.Size(values) < transformInfo.Rank)
- values = new Float[transformInfo.Rank];
+ private readonly PcaTransform _parent;
+ private readonly int _numColumns;
- for (int i = 0; i < transformInfo.Rank; i++)
+ public Mapper(PcaTransform parent, Schema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
- values[i] = VectorUtils.DotProductWithOffset(transformInfo.Eigenvectors[i], 0, ref src) -
- (transformInfo.MeanProjected == null ? 0 : transformInfo.MeanProjected[i]);
+ _parent = parent;
+ _numColumns = parent._numColumns;
+ for (int i = 0; i < _numColumns; i++)
+ {
+ var colPair = _parent.ColumnPairs[i];
+ var colSchemaInfo = new ColumnSchemaInfo(colPair, inputSchema);
+ ValidatePcaInput(Host, colPair.input, colSchemaInfo.InputType);
+ if (colSchemaInfo.InputType.VectorSize != _parent._transformInfos[i].Dimension)
+ {
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input,
+ new VectorType(NumberType.R4, _parent._transformInfos[i].Dimension).ToString(), colSchemaInfo.InputType.ToString());
+ }
+ }
}
- dst = new VBuffer(transformInfo.Rank, values, dst.Indices);
+ public override Schema.Column[] GetOutputColumns()
+ {
+ var result = new Schema.Column[_numColumns];
+ for (int i = 0; i < _numColumns; i++)
+ result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _parent._transformInfos[i].OutputType, null);
+ return result;
+ }
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer)
+ {
+ Contracts.AssertValue(input);
+ Contracts.Assert(0 <= iinfo && iinfo < _numColumns);
+ disposer = null;
+
+ var srcGetter = input.GetGetter>(ColMapNewToOld[iinfo]);
+ var src = default(VBuffer);
+
+ ValueGetter> dstGetter = (ref VBuffer dst) =>
+ {
+ srcGetter(ref src);
+ TransformFeatures(Host, ref src, ref dst, _parent._transformInfos[iinfo]);
+ };
+
+ return dstGetter;
+ }
+
+ private static void TransformFeatures(IExceptionContext ectx, ref VBuffer src, ref VBuffer dst, TransformInfo transformInfo)
+ {
+ ectx.Check(src.Length == transformInfo.Dimension);
+
+ var values = dst.Values;
+ if (Utils.Size(values) < transformInfo.Rank)
+ values = new float[transformInfo.Rank];
+
+ for (int i = 0; i < transformInfo.Rank; i++)
+ {
+ values[i] = VectorUtils.DotProductWithOffset(transformInfo.Eigenvectors[i], 0, ref src) -
+ (transformInfo.MeanProjected == null ? 0 : transformInfo.MeanProjected[i]);
+ }
+
+ dst = new VBuffer(transformInfo.Rank, values, dst.Indices);
+ }
}
[TlcModule.EntryPoint(Name = "Transforms.PcaCalculator",
@@ -554,7 +652,7 @@ private static void TransformFeatures(IExceptionContext ectx, ref VBuffer
public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "Pca", input);
- var view = new PcaTransform(h, input, input.Data);
+ var view = PcaTransform.Create(h, input, input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
@@ -562,4 +660,125 @@ public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Argu
};
}
}
+
+ public sealed class PcaEstimator : IEstimator
+ {
+ internal static class Defaults
+ {
+ public const string WeightColumn = null;
+ public const int Rank = 20;
+ public const int Oversampling = 20;
+ public const bool Center = true;
+ public const int Seed = 0;
+ }
+
+ private readonly IHost _host;
+ private readonly PcaTransform.ColumnInfo[] _columns;
+
+ /// Convinence constructor for simple one column case.
+ ///
+ /// The environment.
+ /// Input column to apply PCA on.
+ /// Output column. Null means is replaced.
+ /// The name of the weight column.
+ /// The number of components in the PCA.
+ /// Oversampling parameter for randomized PCA training.
+ /// If enabled, data is centered to be zero mean.
+ /// The seed for random number generation.
+ public PcaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null,
+ string weightColumn = Defaults.WeightColumn, int rank = Defaults.Rank,
+ int overSampling = Defaults.Oversampling, bool center = Defaults.Center,
+ int? seed = null)
+ : this(env, new PcaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, weightColumn, rank, overSampling, center, seed))
+ {
+ }
+
+ public PcaEstimator(IHostEnvironment env, params PcaTransform.ColumnInfo[] columns)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(PcaEstimator));
+ _columns = columns;
+ }
+
+ public PcaTransform Fit(IDataView input) => new PcaTransform(_host, input, _columns);
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in _columns)
+ {
+ if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+
+ if (col.Kind != SchemaShape.Column.VectorKind.Vector || !col.ItemType.Equals(NumberType.R4))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output,
+ SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+ }
+
+ return new SchemaShape(result.Values);
+ }
+ }
+
+ public static class PcaEstimatorExtensions
+ {
+ private sealed class OutPipelineColumn : Vector
+ {
+ public readonly Vector Input;
+
+ public OutPipelineColumn(Vector input, string weightColumn, int rank,
+ int overSampling, bool center, int? seed = null)
+ : base(new Reconciler(weightColumn, rank, overSampling, center, seed), input)
+ {
+ Input = input;
+ }
+ }
+
+ private sealed class Reconciler : EstimatorReconciler
+ {
+ private readonly PcaTransform.ColumnInfo _colInfo;
+
+ public Reconciler(string weightColumn, int rank, int overSampling, bool center, int? seed = null)
+ {
+ _colInfo = new PcaTransform.ColumnInfo(
+ null, null, weightColumn, rank, overSampling, center, seed);
+ }
+
+ public override IEstimator Reconcile(IHostEnvironment env,
+ PipelineColumn[] toOutput,
+ IReadOnlyDictionary inputNames,
+ IReadOnlyDictionary outputNames,
+ IReadOnlyCollection usedNames)
+ {
+ Contracts.Assert(toOutput.Length == 1);
+ var outCol = (OutPipelineColumn)toOutput[0];
+ var inputColName = inputNames[outCol.Input];
+ var outputColName = outputNames[outCol];
+ return new PcaEstimator(env, inputColName, outputColName,
+ _colInfo.WeightColumn, _colInfo.Rank, _colInfo.Oversampling,
+ _colInfo.Center, _colInfo.Seed);
+ }
+ }
+
+ ///
+ /// Replaces the input vector with its projection to the principal component subspace,
+ /// which can significantly reduce size of vector.
+ ///
+ ///
+ /// The column to apply PCA to.
+ /// The name of the weight column.
+ /// The number of components in the PCA.
+ /// Oversampling parameter for randomized PCA training.
+ /// If enabled, data is centered to be zero mean.
+ /// The seed for random number generation
+ /// Vector containing the principal components.
+ public static Vector ToPrincipalComponents(this Vector input,
+ string weightColumn = PcaEstimator.Defaults.WeightColumn,
+ int rank = PcaEstimator.Defaults.Rank,
+ int overSampling = PcaEstimator.Defaults.Oversampling,
+ bool center = PcaEstimator.Defaults.Center,
+ int? seed = null) => new OutPipelineColumn(input, weightColumn, rank, overSampling, center, seed);
+ }
}
diff --git a/src/Microsoft.ML.PCA/WrappedPcaTransform.cs b/src/Microsoft.ML.PCA/WrappedPcaTransform.cs
deleted file mode 100644
index 1f082e193e..0000000000
--- a/src/Microsoft.ML.PCA/WrappedPcaTransform.cs
+++ /dev/null
@@ -1,116 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using Microsoft.ML.Core.Data;
-using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.StaticPipe;
-using Microsoft.ML.StaticPipe.Runtime;
-using System;
-using System.Collections.Generic;
-using System.Linq;
-
-namespace Microsoft.ML.Runtime.Data
-{
- ///
- public sealed class PcaEstimator : TrainedWrapperEstimatorBase
- {
- private readonly PcaTransform.Arguments _args;
-
- ///
- /// The environment.
- /// Input column to apply PCA on.
- /// Output column. Null means is replaced.
- /// The number of components in the PCA.
- /// A delegate to apply all the advanced arguments to the algorithm.
- public PcaEstimator(IHostEnvironment env,
- string inputColumn,
- string outputColumn = null,
- int rank = PcaTransform.Defaults.Rank,
- Action advancedSettings = null)
- : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, rank, advancedSettings)
- {
- }
-
- ///
- /// The environment.
- /// Pairs of columns to run the PCA on.
- /// The number of components in the PCA.
- /// A delegate to apply all the advanced arguments to the algorithm.
- public PcaEstimator(IHostEnvironment env, (string input, string output)[] columns,
- int rank = PcaTransform.Defaults.Rank,
- Action advancedSettings = null)
- : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PcaEstimator)))
- {
- foreach (var (input, output) in columns)
- {
- Host.CheckUserArg(Utils.Size(input) > 0, nameof(input));
- Host.CheckValue(output, nameof(input));
- }
-
- _args = new PcaTransform.Arguments();
- _args.Column = columns.Select(x => new PcaTransform.Column { Source = x.input, Name = x.output }).ToArray();
- _args.Rank = rank;
-
- advancedSettings?.Invoke(_args);
- }
-
- public override TransformWrapper Fit(IDataView input)
- {
- return new TransformWrapper(Host, new PcaTransform(Host, _args, input));
- }
- }
-
- ///
- /// Extensions for statically typed .
- ///
- public static class PcaEstimatorExtensions
- {
- private sealed class OutPipelineColumn : Vector
- {
- public readonly Vector Input;
-
- public OutPipelineColumn(Vector input, int rank, Action advancedSettings)
- : base(new Reconciler(null, rank, advancedSettings), input)
- {
- Input = input;
- }
- }
-
- private sealed class Reconciler : EstimatorReconciler
- {
- private readonly int _rank;
- private readonly Action _advancedSettings;
-
- public Reconciler(PipelineColumn weightColumn, int rank, Action advancedSettings)
- {
- _rank = rank;
- _advancedSettings = advancedSettings;
- }
-
- public override IEstimator Reconcile(IHostEnvironment env,
- PipelineColumn[] toOutput,
- IReadOnlyDictionary inputNames,
- IReadOnlyDictionary outputNames,
- IReadOnlyCollection usedNames)
- {
- Contracts.Assert(toOutput.Length == 1);
-
- var pairs = new List<(string input, string output)>();
- foreach (var outCol in toOutput)
- pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
-
- return new PcaEstimator(env, pairs.ToArray(), _rank, _advancedSettings);
- }
- }
-
- /// Replace current vector with its principal components. Can significantly reduce size of vector.
- ///
- /// The column to apply PCA to.
- /// The number of components in the PCA.
- /// A delegate to apply all the advanced arguments to the algorithm.
- public static Vector ToPrincipalComponents(this Vector input,
- int rank = PcaTransform.Defaults.Rank,
- Action advancedSettings = null) => new OutPipelineColumn(input, rank, advancedSettings);
- }
-}
diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
index 7095eeaaed..bb0ca58a38 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
+++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
@@ -119,7 +119,7 @@ Transforms.ModelCombiner Combines a sequence of TransformModels into a single mo
Transforms.NGramTranslator Produces a bag of counts of ngrams (sequences of consecutive values of length 1-n) in a given vector of keys. It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. Microsoft.ML.Runtime.Transforms.TextAnalytics NGramTransform Microsoft.ML.Runtime.Data.NgramTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
Transforms.NoOperation Does nothing. Microsoft.ML.Runtime.Data.NopTransform Nop Microsoft.ML.Runtime.Data.NopTransform+NopInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Runtime.DataPipe.OptionalColumnTransform MakeOptional Microsoft.ML.Runtime.DataPipe.OptionalColumnTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
-Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Runtime.Data.PcaTransform Calculate Microsoft.ML.Runtime.Data.PcaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
+Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.PcaTransform Calculate Microsoft.ML.Transforms.PcaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Runtime.Data.RandomNumberGenerator Generate Microsoft.ML.Runtime.Data.GenerateNumberTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.Runtime.EntryPoints.SelectRows FilterByRange Microsoft.ML.Runtime.Data.RangeFilter+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index 6d4cae0995..7234173e9f 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -856,5 +856,25 @@ public void TextNormalizeStatic()
type = schema.GetColumnType(numbers);
Assert.True(!type.IsVector && type.ItemType.IsText);
}
+
+ [Fact]
+ public void TestPcaStatic()
+ {
+ var env = new ConsoleEnvironment(seed: 1);
+ var dataSource = GetDataPath("generated_regression_dataset.csv");
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ separator: ';', hasHeader: true);
+ var data = reader.Read(dataSource);
+ var est = reader.MakeNewEstimator()
+ .Append(r => (r.label, pca: r.features.ToPrincipalComponents(rank: 5)));
+ var tdata = est.Fit(data).Transform(data);
+ var schema = tdata.AsDynamic.Schema;
+
+ Assert.True(schema.TryGetColumnIndex("pca", out int pca));
+ var type = schema[pca].Type;
+ Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4);
+ Assert.True(type.VectorSize == 5);
+ }
}
}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
index 5561b4df7d..8f1089e0dc 100644
--- a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
@@ -2,11 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.IO;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Transforms;
-using System.IO;
using Xunit;
using Xunit.Abstractions;
@@ -14,43 +14,57 @@ namespace Microsoft.ML.Tests.Transformers
{
public sealed class PcaTests : TestDataPipeBase
{
+ private readonly ConsoleEnvironment _env;
+ private readonly string _dataSource;
+ private readonly TextSaver _saver;
+
public PcaTests(ITestOutputHelper helper)
: base(helper)
{
+ _env = new ConsoleEnvironment(seed: 1);
+ _dataSource = GetDataPath("generated_regression_dataset.csv");
+ _saver = new TextSaver(_env, new TextSaver.Arguments { Silent = true, OutputHeader = false });
}
[Fact]
public void PcaWorkout()
{
- var env = new ConsoleEnvironment(seed: 1, conc: 1);
- string dataSource = GetDataPath("generated_regression_dataset.csv");
- var data = TextLoader.CreateReader(env,
- c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ var data = TextLoader.CreateReader(_env,
+ c => (label: c.LoadFloat(11), weight: c.LoadFloat(0), features: c.LoadFloat(1, 10)),
separator: ';', hasHeader: true)
- .Read(dataSource);
+ .Read(_dataSource);
- var invalidData = TextLoader.CreateReader(env,
- c => (label: c.LoadFloat(11), features: c.LoadText(0, 10)),
+ var invalidData = TextLoader.CreateReader(_env,
+ c => (label: c.LoadFloat(11), weight: c.LoadFloat(0), features: c.LoadText(1, 10)),
separator: ';', hasHeader: true)
- .Read(dataSource);
+ .Read(_dataSource);
+
+ var est = new PcaEstimator(_env, "features", "pca", rank: 4, seed: 10);
+ TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
+
+ var estNonDefaultArgs = new PcaEstimator(_env, "features", "pca", rank: 3, weightColumn: "weight", overSampling: 2, center: false);
+ TestEstimatorCore(estNonDefaultArgs, data.AsDynamic, invalidInput: invalidData.AsDynamic);
- var est = new PcaEstimator(env, "features", "pca", rank: 5, advancedSettings: s => {
- s.Seed = 1;
- });
+ Done();
+ }
- // The following call fails because of the following issue
- // https://github.com/dotnet/machinelearning/issues/969
- // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
+ [Fact]
+ public void TestPcaEstimator()
+ {
+ var data = TextLoader.CreateReader(_env,
+ c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+ separator: ';', hasHeader: true)
+ .Read(_dataSource);
+ var est = new PcaEstimator(_env, "features", "pca", rank: 5, seed: 1);
var outputPath = GetOutputPath("PCA", "pca.tsv");
- using (var ch = env.Start("save"))
+ using (var ch = _env.Start("save"))
{
- var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false });
- IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4);
- savedData = new ChooseColumnsTransform(env, savedData, "pca");
+ IDataView savedData = TakeFilter.Create(_env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4);
+ savedData = new ChooseColumnsTransform(_env, savedData, "pca");
using (var fs = File.Create(outputPath))
- DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);
+ DataSaverUtils.SaveDataView(ch, _saver, savedData, fs, keepHidden: true);
}
CheckEquality("PCA", "pca.tsv");