diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs
index ea141c7187..e2017a8e34 100644
--- a/src/Microsoft.ML.Transforms/GcnTransform.cs
+++ b/src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -12,6 +12,7 @@
 using Microsoft.ML.EntryPoints;
 using Microsoft.ML.Internal.CpuMath;
 using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms;
 
@@ -313,11 +314,13 @@ private protected override void SaveModel(ModelSaveContext ctx)
 
         private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
 
-        private sealed class Mapper : OneToOneMapperBase
+        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
         {
             private readonly DataViewType[] _srcTypes;
             private readonly int[] _srcCols;
             private readonly DataViewType[] _types;
+            private readonly LpNormNormalizingEstimatorBase.NormFunction[] _norms;
+            private readonly bool[] _ensureZeroMeans;
             private readonly LpNormNormalizingTransformer _parent;
 
             public Mapper(LpNormNormalizingTransformer parent, DataViewSchema inputSchema)
@@ -327,12 +330,16 @@ public Mapper(LpNormNormalizingTransformer parent, DataViewSchema inputSchema)
                 _types = new DataViewType[_parent.ColumnPairs.Length];
                 _srcTypes = new DataViewType[_parent.ColumnPairs.Length];
                 _srcCols = new int[_parent.ColumnPairs.Length];
+                _norms = new LpNormNormalizingEstimatorBase.NormFunction[_parent.ColumnPairs.Length];
+                _ensureZeroMeans = new bool[_parent.ColumnPairs.Length];
                 for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                 {
                     inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]);
                     var srcCol = inputSchema[_srcCols[i]];
                     _srcTypes[i] = srcCol.Type;
                     _types[i] = srcCol.Type;
+                    _norms[i] = _parent._columns[i].Norm;
+                    _ensureZeroMeans[i] = _parent._columns[i].EnsureZeroMean;
                 }
             }
 
@@ -594,6 +601,128 @@ private static float Mean(ReadOnlySpan<float> src, int length)
                     return 0;
                 return CpuMathUtils.Sum(src) / length;
             }
+
+            public bool CanSaveOnnx(OnnxContext ctx) => true;
+
+            public void SaveAsOnnx(OnnxContext ctx)
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+
+                for (int iinfo = 0; iinfo < _srcCols.Length; ++iinfo)
+                {
+                    string inputColumnName = InputSchema[_srcCols[iinfo]].Name;
+                    if (!ctx.ContainsColumn(inputColumnName))
+                    {
+                        ctx.RemoveColumn(inputColumnName, false);
+                        continue;
+                    }
+
+                    if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_srcTypes[iinfo], inputColumnName)))
+                    {
+                        ctx.RemoveColumn(inputColumnName, true);
+                    }
+                }
+            }
+
+            private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
+            {
+                string opType;
+
+                if ((_norms[iinfo] != LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation) && (_ensureZeroMeans[iinfo] == false))
+                {
+                    string strNorm;
+                    if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L1)
+                        strNorm = "L1";
+                    else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L2)
+                        strNorm = "L2";
+                    else
+                        strNorm = "MAX";
+                    opType = "Normalizer";
+                    var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+                    node.AddAttribute("norm", strNorm);
+                    return true;
+                }
+
+                opType = "ReduceMean";
+                string meanOfInput = ctx.AddIntermediateVariable(_types[iinfo], "MeanOfInput", true);
+                var meanNode = ctx.CreateNode(opType, srcVariableName, meanOfInput, ctx.GetNodeName(opType), "");
+                meanNode.AddAttribute("axes", new long[] { 1 });
+
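+                // Subtract the per-row mean computed above (ReduceMean along axis 1); the
+                // branches below build the requested norm from these mean-subtracted values.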
+                opType = "Sub";
+                string inputMinusMean = ctx.AddIntermediateVariable(_types[iinfo], "InputMinusMean");
+                var subtractNode = ctx.CreateNode(opType, new[] { srcVariableName, meanOfInput }, new[] { inputMinusMean }, ctx.GetNodeName(opType), "");
+
+                if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L1)
+                {
+                    opType = "Abs";
+                    string absOfInput = ctx.AddIntermediateVariable(_types[iinfo], "AbsOfInput");
+                    var absNode = ctx.CreateNode(opType, inputMinusMean, absOfInput, ctx.GetNodeName(opType), "");
+
+                    opType = "ReduceSum";
+                    string sumOfAbsOfInput = ctx.AddIntermediateVariable(_types[iinfo], "SumOfAbsOfInput", true);
+                    var sumOfAbsNode = ctx.CreateNode(opType, absOfInput, sumOfAbsOfInput, ctx.GetNodeName(opType), "");
+                    sumOfAbsNode.AddAttribute("axes", new long[] { 1 });
+
+                    opType = "Div";
+                    var l1Node = ctx.CreateNode(opType, new[] { inputMinusMean, sumOfAbsOfInput }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
+                }
+                else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.L2)
+                {
+                    opType = "Pow";
+                    string two = ctx.AddInitializer(2.0f);
+                    string squareOfInput = ctx.AddIntermediateVariable(_types[iinfo], "SquareOfInput", true);
+                    var squareNode = ctx.CreateNode(opType, new[] { inputMinusMean, two }, new[] { squareOfInput }, ctx.GetNodeName(opType), "");
+
+                    opType = "ReduceSum";
+                    string sumOfSquares = ctx.AddIntermediateVariable(_types[iinfo], "SumOfSquares", true);
+                    var sumOfSquaresNode = ctx.CreateNode(opType, squareOfInput, sumOfSquares, ctx.GetNodeName(opType), "");
+                    sumOfSquaresNode.AddAttribute("axes", new long[] { 1 });
+
+                    opType = "Sqrt";
+                    string squareRoot = ctx.AddIntermediateVariable(_types[iinfo], "SquareRoot", true);
+                    var squareRootNode = ctx.CreateNode(opType, sumOfSquares, squareRoot, ctx.GetNodeName(opType), "");
+
+                    opType = "Div";
+                    var l2Node = ctx.CreateNode(opType, new[] { inputMinusMean, squareRoot }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
+                }
+                else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.Infinity)
+                {
+                    opType = "ReduceMax";
+                    string maxOfInput = ctx.AddIntermediateVariable(_types[iinfo], "MaxOfInput", true);
+                    var maxNode = ctx.CreateNode(opType, inputMinusMean, maxOfInput, ctx.GetNodeName(opType), "");
+                    maxNode.AddAttribute("axes", new long[] { 1 });
+
+                    opType = "Div";
+                    var lMaxNode = ctx.CreateNode(opType, new[] { inputMinusMean, maxOfInput }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
+                }
+                else if (_norms[iinfo] == LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation)
+                {
+                    // first calculate the standard deviation
+                    opType = "Pow";
+                    string two = ctx.AddInitializer(2.0f);
+                    string squareOfInputMinusMean = ctx.AddIntermediateVariable(_types[iinfo], "SquareOfInputMinusMean", true);
+                    var squareOfInputMinusMeanNode = ctx.CreateNode(opType, new[] { inputMinusMean, two }, new[] { squareOfInputMinusMean }, ctx.GetNodeName(opType), "");
+
+                    opType = "ReduceMean";
+                    string average = ctx.AddIntermediateVariable(_types[iinfo], "SumOfSquares", true);
+                    var sumOfSquaresNode = ctx.CreateNode(opType, squareOfInputMinusMean, average, ctx.GetNodeName(opType), "");
+                    sumOfSquaresNode.AddAttribute("axes", new long[] { 1 });
+
+                    opType = "Sqrt";
+                    string stdDev = ctx.AddIntermediateVariable(_types[iinfo], "SquareRoot", true);
+                    var stdDevNode = ctx.CreateNode(opType, average, stdDev, ctx.GetNodeName(opType), "");
+
+                    opType = "Div";
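+                    // When EnsureZeroMean is false, divide the original input rather than the
+                    // mean-subtracted values by the standard deviation.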
+                    string input = _ensureZeroMeans[iinfo] ? inputMinusMean : srcVariableName;
+                    var lStdDevNode = ctx.CreateNode(opType, new[] { input, stdDev }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
+                }
+                else
+                {
+                    Contracts.Assert(false);
+                    return false;
+                }
+                return true;
+            }
         }
     }
 
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index af87577042..8071dc64ee 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -180,6 +180,68 @@ public void KmeansOnnxConversionTest()
             Done();
         }
 
+        private class DataPoint
+        {
+            [VectorType(3)]
+            public float[] Features { get; set; }
+        }
+
+        [Fact]
+        void LpNormOnnxConversionTest()
+        {
+            var mlContext = new MLContext(seed: 1);
+
+            var samples = new List<DataPoint>()
+            {
+                new DataPoint() { Features = new float[3] {0.01f, 0.02f, 0.03f} },
+                new DataPoint() { Features = new float[3] {0.04f, 0.05f, 0.06f} },
+                new DataPoint() { Features = new float[3] {0.07f, 0.08f, 0.09f} },
+                new DataPoint() { Features = new float[3] {0.10f, 0.11f, 0.12f} },
+                new DataPoint() { Features = new float[3] {0.13f, 0.14f, 0.15f} }
+            };
+            var dataView = mlContext.Data.LoadFromEnumerable(samples);
+
+            LpNormNormalizingEstimatorBase.NormFunction[] norms =
+            {
+                LpNormNormalizingEstimatorBase.NormFunction.L1,
+                LpNormNormalizingEstimatorBase.NormFunction.L2,
+                LpNormNormalizingEstimatorBase.NormFunction.Infinity,
+                LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation
+            };
+
+            bool[] ensureZeroMeans = { true, false };
+            foreach (var ensureZeroMean in ensureZeroMeans)
+            {
+                foreach (var norm in norms)
+                {
+                    var pipe = mlContext.Transforms.NormalizeLpNorm(nameof(DataPoint.Features), norm: norm, ensureZeroMean: ensureZeroMean);
+
+                    var model = pipe.Fit(dataView);
+                    var transformedData = model.Transform(dataView);
+                    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+                    var onnxFileName = $"LpNorm-{norm.ToString()}-{ensureZeroMean}.onnx";
+                    var onnxModelPath = GetOutputPath(onnxFileName);
+
+                    SaveOnnxModel(onnxModel, onnxModelPath, null);
+
+                    // Compare results produced by ML.NET and ONNX's runtime.
+                    if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
+                    {
+                        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                        string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                        string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                        var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                        var onnxTransformer = onnxEstimator.Fit(dataView);
+                        var onnxResult = onnxTransformer.Transform(dataView);
+                        CompareSelectedR4VectorColumns(nameof(DataPoint.Features), outputNames[0], transformedData, onnxResult, 3);
+                    }
+                }
+            }
+
+            Done();
+        }
+
         [Fact]
         void CommandLineOnnxConversionTest()
         {