From 1a9e7aa88ecf572b127a72bcd3f89902d0ed337c Mon Sep 17 00:00:00 2001 From: Abhishek Goswami Date: Tue, 20 Nov 2018 14:53:48 -0800 Subject: [PATCH] Convert LdaTransform to IEstimator/ITransformer API (#1410) --- src/Microsoft.ML.Legacy/CSharpApi.cs | 24 +- .../EntryPoints/TextAnalytics.cs | 14 +- .../Text/LdaStaticExtensions.cs | 174 +++ .../Text/LdaTransform.cs | 1061 ++++++++++------- .../Text/TextStaticExtensions.cs | 52 - .../Text/WrappedTextTransformers.cs | 46 - src/Microsoft.ML.Transforms/TextCatalog.cs | 43 + .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../Common/EntryPoints/core_manifest.json | 22 +- .../UnitTests/TestEntryPoints.cs | 2 +- .../StaticPipeTests.cs | 18 +- .../DataPipe/TestDataPipe.cs | 28 +- .../Transformers/TextFeaturizerTests.cs | 42 +- 13 files changed, 931 insertions(+), 597 deletions(-) create mode 100644 src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 4a74672e70..c891b2039c 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -13997,10 +13997,10 @@ public LabelToFloatConverterPipelineStep(Output output) namespace Legacy.Transforms { - public sealed partial class LdaTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn, IOneToOneColumn { /// - /// The number of topics in the LDA + /// The number of topics /// public int? NumTopic { get; set; } @@ -14099,15 +14099,15 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo public void AddColumn(string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(inputColumn)); + var list = Column == null ? 
new List() : new List(Column); + list.Add(OneToOneColumn.Create(inputColumn)); Column = list.ToArray(); } public void AddColumn(string outputColumn, string inputColumn) { - var list = Column == null ? new List() : new List(Column); - list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); + var list = Column == null ? new List() : new List(Column); + list.Add(OneToOneColumn.Create(outputColumn, inputColumn)); Column = list.ToArray(); } @@ -14115,10 +14115,10 @@ public void AddColumn(string outputColumn, string inputColumn) /// /// New column definition(s) (optional form: name:srcs) /// - public LdaTransformColumn[] Column { get; set; } + public LatentDirichletAllocationTransformerColumn[] Column { get; set; } /// - /// The number of topics in the LDA + /// The number of topics /// [TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})] public int NumTopic { get; set; } = 100; @@ -14153,14 +14153,14 @@ public void AddColumn(string outputColumn, string inputColumn) public int LikelihoodInterval { get; set; } = 5; /// - /// The threshold of maximum count of tokens per doc + /// The number of training threads. Default value depends on number of logical processors. /// - public int NumMaxDocToken { get; set; } = 512; + public int NumThreads { get; set; } /// - /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc /// - public int? 
NumThreads { get; set; } + public int NumMaxDocToken { get; set; } = 512; /// /// The number of words to summarize the topic diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs index 416565be52..b7b93584e9 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Text; +using System.Linq; [assembly: LoadableClass(typeof(void), typeof(TextAnalytics), null, typeof(SignatureEntryPointModule), "TextAnalytics")] @@ -118,18 +119,21 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T } [TlcModule.EntryPoint(Name = "Transforms.LightLda", - Desc = LdaTransform.Summary, - UserName = LdaTransform.UserName, - ShortName = LdaTransform.ShortName, + Desc = LatentDirichletAllocationTransformer.Summary, + UserName = LatentDirichletAllocationTransformer.UserName, + ShortName = LatentDirichletAllocationTransformer.ShortName, XmlInclude = new[] { @"", @"" })] - public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input) + public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input); - var view = new LdaTransform(h, input, input.Data); + var cols = input.Column.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray(); + var est = new LatentDirichletAllocationEstimator(h, cols); + var view = est.Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), diff --git 
a/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs new file mode 100644 index 0000000000..05acdca178 --- /dev/null +++ b/src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs @@ -0,0 +1,174 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms.Text; +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.StaticPipe +{ + /// + /// Information on the result of fitting a LDA transform. + /// + public sealed class LdaFitResult + { + /// + /// For user defined delegates that accept instances of the containing type. + /// + /// + public delegate void OnFit(LdaFitResult result); + + public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary; + public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary) + { + LdaTopicSummary = ldaTopicSummary; + } + } + + public static class LdaStaticExtensions + { + private struct Config + { + public readonly int NumTopic; + public readonly Single AlphaSum; + public readonly Single Beta; + public readonly int MHStep; + public readonly int NumIter; + public readonly int LikelihoodInterval; + public readonly int NumThread; + public readonly int NumMaxDocToken; + public readonly int NumSummaryTermPerTopic; + public readonly int NumBurninIter; + public readonly bool ResetRandomGenerator; + + public readonly Action OnFit; + + public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval, + int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator, + Action onFit) + { + NumTopic = numTopic; + AlphaSum = alphaSum; + 
Beta = beta; + MHStep = mhStep; + NumIter = numIter; + LikelihoodInterval = likelihoodInterval; + NumThread = numThread; + NumMaxDocToken = numMaxDocToken; + NumSummaryTermPerTopic = numSummaryTermPerTopic; + NumBurninIter = numBurninIter; + ResetRandomGenerator = resetRandomGenerator; + + OnFit = onFit; + } + } + + private static Action Wrap(LdaFitResult.OnFit onFit) + { + if (onFit == null) + return null; + + return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary)); + } + + private interface ILdaCol + { + PipelineColumn Input { get; } + Config Config { get; } + } + + private sealed class ImplVector : Vector, ILdaCol + { + public PipelineColumn Input { get; } + public Config Config { get; } + public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input) + { + Input = input; + Config = config; + } + } + + private sealed class Rec : EstimatorReconciler + { + public static readonly Rec Inst = new Rec(); + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length]; + Action onFit = null; + for (int i = 0; i < toOutput.Length; ++i) + { + var tcol = (ILdaCol)toOutput[i]; + + infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + tcol.Config.NumTopic, + tcol.Config.AlphaSum, + tcol.Config.Beta, + tcol.Config.MHStep, + tcol.Config.NumIter, + tcol.Config.LikelihoodInterval, + tcol.Config.NumThread, + tcol.Config.NumMaxDocToken, + tcol.Config.NumSummaryTermPerTopic, + tcol.Config.NumBurninIter, + tcol.Config.ResetRandomGenerator); + + if (tcol.Config.OnFit != null) + { + int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call. 
+ onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii)); + } + } + + var est = new LatentDirichletAllocationEstimator(env, infos); + if (onFit == null) + return est; + + return est.WithOnFitDelegate(onFit); + } + } + + /// + /// A vector of floats representing the document. + /// The number of topics. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. + /// Called upon fitting with the learnt enumeration on the dataset. + public static Vector ToLdaTopicVector(this Vector input, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + Single beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator, + LdaFitResult.OnFit onFit = null) + { + Contracts.CheckValue(input, nameof(input)); + return new ImplVector(input, + new 
Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, + numBurninIterations, resetRandomGenerator, Wrap(onFit))); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 0c18f713d3..3466e2219a 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -2,13 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -18,12 +12,23 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.Transforms.Text; +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; + +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), typeof(LatentDirichletAllocationTransformer.Arguments), typeof(SignatureDataTransform), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature, "Lda")] -[assembly: LoadableClass(typeof(LdaTransform), typeof(LdaTransform.Arguments), typeof(SignatureDataTransform), - LdaTransform.UserName, LdaTransform.LoaderSignature, LdaTransform.ShortName, DocName = "transform/LdaTransform.md")] +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(IDataTransform), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadDataTransform), + "Latent Dirichlet Allocation Transform", 
LatentDirichletAllocationTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(LdaTransform), null, typeof(SignatureLoadDataTransform), - LdaTransform.UserName, LdaTransform.LoaderSignature)] +[assembly: LoadableClass(LatentDirichletAllocationTransformer.Summary, typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadModel), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(LatentDirichletAllocationTransformer), null, typeof(SignatureLoadRowMapper), + "Latent Dirichlet Allocation Transform", LatentDirichletAllocationTransformer.LoaderSignature)] namespace Microsoft.ML.Transforms.Text { @@ -41,60 +46,60 @@ namespace Microsoft.ML.Transforms.Text // https://github.com/Microsoft/LightLDA // // See - // for an example on how to use LdaTransform. + // for an example on how to use LatentDirichletAllocationTransformer. /// - public sealed class LdaTransform : OneToOneTransformBase + public sealed class LatentDirichletAllocationTransformer : OneToOneTransformerBase { public sealed class Arguments : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 49)] public Column[] Column; - [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics in the LDA", SortOrder = 50)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics", SortOrder = 50)] [TGUI(SuggestedSweeps = "20,40,100,200")] [TlcModule.SweepableDiscreteParam("NumTopic", new object[] { 20, 40, 100, 200 })] - public int NumTopic = 100; + public int NumTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] [TGUI(SuggestedSweeps = "1,10,100,200")] [TlcModule.SweepableDiscreteParam("AlphaSum", new object[] { 1, 10, 100, 200 })] - public Single AlphaSum = 
100; + public float AlphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] [TGUI(SuggestedSweeps = "0.01,0.015,0.07,0.02")] [TlcModule.SweepableDiscreteParam("Beta", new object[] { 0.01f, 0.015f, 0.07f, 0.02f })] - public Single Beta = 0.01f; + public float Beta = LatentDirichletAllocationEstimator.Defaults.Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] [TGUI(SuggestedSweeps = "2,4,8,16")] [TlcModule.SweepableDiscreteParam("Mhstep", new object[] { 2, 4, 8, 16 })] - public int Mhstep = 4; + public int Mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep; [Argument(ArgumentType.AtMostOnce, HelpText = "Number of iterations", ShortName = "iter")] [TGUI(SuggestedSweeps = "100,200,300,400")] [TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 100, 200, 300, 400 })] - public int NumIterations = 200; + public int NumIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Compute log likelihood over local dataset on this iteration interval", ShortName = "llInterval")] - public int LikelihoodInterval = 5; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] - public int NumMaxDocToken = 512; + public int LikelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval; // REVIEW: Should change the default when multi-threading support is optimized. [Argument(ArgumentType.AtMostOnce, HelpText = "The number of training threads. Default value depends on number of logical processors.", ShortName = "t", SortOrder = 50)] - public int? 
NumThreads; + public int NumThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold of maximum count of tokens per doc", ShortName = "maxNumToken", SortOrder = 50)] + public int NumMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of words to summarize the topic", ShortName = "ns")] - public int NumSummaryTermPerTopic = 10; + public int NumSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "The number of burn-in iterations", ShortName = "burninIter")] [TGUI(SuggestedSweeps = "10,20,30,40")] [TlcModule.SweepableDiscreteParam("NumBurninIterations", new object[] { 10, 20, 30, 40 })] - public int NumBurninIterations = 10; + public int NumBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations; [Argument(ArgumentType.AtMostOnce, HelpText = "Reset the random number generator for each document", ShortName = "reset")] - public bool ResetRandomGenerator; + public bool ResetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the topic-word summary in text format", ShortName = "summary")] public bool OutputTopicWordSummary; @@ -102,14 +107,14 @@ public sealed class Arguments : TransformInputBase public sealed class Column : OneToOneColumn { - [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics in the LDA")] + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of topics")] public int? NumTopic; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on document-topic vectors")] - public Single? AlphaSum; + public float? AlphaSum; [Argument(ArgumentType.AtMostOnce, HelpText = "Dirichlet prior on vocab-topic vectors")] - public Single? Beta; + public float? 
Beta; [Argument(ArgumentType.Multiple, HelpText = "Number of Metropolis Hasting step")] public int? Mhstep; @@ -155,11 +160,13 @@ public bool TryUnparse(StringBuilder sb) } } - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; public readonly int NumTopic; - public readonly Single AlphaSum; - public readonly Single Beta; + public readonly float AlphaSum; + public readonly float Beta; public readonly int MHStep; public readonly int NumIter; public readonly int LikelihoodInterval; @@ -169,50 +176,78 @@ private sealed class ColInfoEx public readonly int NumBurninIter; public readonly bool ResetRandomGenerator; - public ColInfoEx(IExceptionContext ectx, Column item, Arguments args) + /// + /// Describes how the transformer handles one column pair. + /// + /// The column representing the document as a vector of floats. + /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// The number of topics. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. 
+ public ColumnInfo(string input, + string output = null, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + float beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhStep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIter = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThread = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) { - Contracts.AssertValue(ectx); - - NumTopic = item.NumTopic ?? args.NumTopic; - Contracts.CheckUserArg(NumTopic > 0, nameof(item.NumTopic), "Must be positive."); - - AlphaSum = item.AlphaSum ?? args.AlphaSum; - - Beta = item.Beta ?? args.Beta; - - MHStep = item.Mhstep ?? args.Mhstep; - ectx.CheckUserArg(MHStep > 0, nameof(item.Mhstep), "Must be positive."); - - NumIter = item.NumIterations ?? args.NumIterations; - ectx.CheckUserArg(NumIter > 0, nameof(item.NumIterations), "Must be positive."); - - LikelihoodInterval = item.LikelihoodInterval ?? args.LikelihoodInterval; - ectx.CheckUserArg(LikelihoodInterval > 0, nameof(item.LikelihoodInterval), "Must be positive."); - - NumThread = item.NumThreads ?? args.NumThreads ?? 0; - ectx.CheckUserArg(NumThread >= 0, nameof(item.NumThreads), "Must be positive or zero."); - - NumMaxDocToken = item.NumMaxDocToken ?? args.NumMaxDocToken; - ectx.CheckUserArg(NumMaxDocToken > 0, nameof(item.NumMaxDocToken), "Must be positive."); - - NumSummaryTermPerTopic = item.NumSummaryTermPerTopic ?? 
args.NumSummaryTermPerTopic; - ectx.CheckUserArg(NumSummaryTermPerTopic > 0, nameof(item.NumSummaryTermPerTopic), "Must be positive"); - - NumBurninIter = item.NumBurninIterations ?? args.NumBurninIterations; - ectx.CheckUserArg(NumBurninIter >= 0, nameof(item.NumBurninIterations), "Must be non-negative."); + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValueOrNull(output); + Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive."); + Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive."); + Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive."); + Contracts.CheckParam(likelihoodInterval > 0, nameof(likelihoodInterval), "Must be positive."); + Contracts.CheckParam(numThread >= 0, nameof(numThread), "Must be positive or zero."); + Contracts.CheckParam(numMaxDocToken > 0, nameof(numMaxDocToken), "Must be positive."); + Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); + Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); + + Input = input; + Output = output ?? input; + NumTopic = numTopic; + AlphaSum = alphaSum; + Beta = beta; + MHStep = mhStep; + NumIter = numIter; + LikelihoodInterval = likelihoodInterval; + NumThread = numThread; + NumMaxDocToken = numMaxDocToken; + NumSummaryTermPerTopic = numSummaryTermPerTopic; + NumBurninIter = numBurninIter; + ResetRandomGenerator = resetRandomGenerator; + } - ResetRandomGenerator = item.ResetRandomGenerator ?? 
args.ResetRandomGenerator; + internal ColumnInfo(Column item, Arguments args) : + this(item.Source, item.Name, + args.NumTopic, args.AlphaSum, args.Beta, args.Mhstep, args.NumIterations, + args.LikelihoodInterval, args.NumThreads, args.NumMaxDocToken, args.NumSummaryTermPerTopic, args.NumBurninIterations, args.ResetRandomGenerator) + { } - public ColInfoEx(IExceptionContext ectx, ModelLoadContext ctx) + internal ColumnInfo(IExceptionContext ectx, ModelLoadContext ctx) { Contracts.AssertValue(ectx); ectx.AssertValue(ctx); // *** Binary format *** // int NumTopic; - // Single AlphaSum; - // Single Beta; + // float AlphaSum; + // float Beta; // int MHStep; // int NumIter; // int LikelihoodInterval; @@ -253,14 +288,14 @@ public ColInfoEx(IExceptionContext ectx, ModelLoadContext ctx) ResetRandomGenerator = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); // *** Binary format *** // int NumTopic; - // Single AlphaSum; - // Single Beta; + // float AlphaSum; + // float Beta; // int MHStep; // int NumIter; // int LikelihoodInterval; @@ -284,312 +319,41 @@ public void Save(ModelSaveContext ctx) } } - public const string LoaderSignature = "LdaTransform"; - private static VersionInfo GetVersionInfo() + /// + /// Provide details about the topics discovered by LightLDA. + /// + public sealed class LdaSummary { - return new VersionInfo( - modelSignature: "LIGHTLDA", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); - } - - private readonly ColInfoEx[] _exes; - private readonly LdaState[] _ldas; - private readonly ColumnType[] _types; - private readonly bool _saveText; + // For each topic, provide information about the (item, score) pairs. 
+ public readonly ImmutableArray> ItemScoresPerTopic; - private const string RegistrationName = "LightLda"; - private const string WordTopicModelFilename = "word_topic_summary.txt"; - internal const string Summary = "The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation."; - internal const string UserName = "Latent Dirichlet Allocation Transform"; - internal const string ShortName = "LightLda"; + // For each topic, provide information about the (item, word, score) tuple. + public readonly ImmutableArray> WordScoresPerTopic; - public LdaTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, args.Column, input, TestType) - { - Host.CheckValue(args, nameof(args)); - Host.CheckUserArg(args.NumTopic > 0, nameof(args.NumTopic), "Must be positive."); - Host.CheckValue(input, nameof(input)); - Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column)); - _exes = new ColInfoEx[Infos.Length]; - _types = new ColumnType[Infos.Length]; - _ldas = new LdaState[Infos.Length]; - _saveText = args.OutputTopicWordSummary; - for (int i = 0; i < Infos.Length; i++) + internal LdaSummary(ImmutableArray> itemScoresPerTopic) { - var ex = new ColInfoEx(Host, args.Column[i], args); - _exes[i] = ex; - _types[i] = new VectorType(NumberType.Float, ex.NumTopic); + ItemScoresPerTopic = itemScoresPerTopic; } - using (var ch = Host.Start("Train")) - { - Train(ch, input, _ldas); - } - Metadata.Seal(); - } - private void Dispose(bool disposing) - { - if (_ldas != null) + internal LdaSummary(ImmutableArray> wordScoresPerTopic) { - foreach (var state in _ldas) - state?.Dispose(); + WordScoresPerTopic = wordScoresPerTopic; } - if (disposing) - GC.SuppressFinalize(this); - } - - public void Dispose() - { - Dispose(true); - } - - ~LdaTransform() - { - Dispose(false); - } - - private LdaTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) - { - 
Host.AssertValue(ctx); - - // *** Binary format *** - // - // - // ldaState[num infos]: The LDA parameters - - // Note: infos.length would be just one in most cases. - _exes = new ColInfoEx[Infos.Length]; - _ldas = new LdaState[Infos.Length]; - _types = new ColumnType[Infos.Length]; - for (int i = 0; i < _ldas.Length; i++) - { - _ldas[i] = new LdaState(Host, ctx); - _exes[i] = _ldas[i].InfoEx; - _types[i] = new VectorType(NumberType.Float, _ldas[i].InfoEx.NumTopic); - } - using (var ent = ctx.Repository.OpenEntryOrNull("model", WordTopicModelFilename)) - { - _saveText = ent != null; - } - Metadata.Seal(); - } - - public static LdaTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - - h.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - h.CheckValue(input, nameof(input)); - - return h.Apply( - "Loading Model", - ch => - { - // *** Binary Format *** - // int: sizeof(Float) - // - int cbFloat = ctx.Reader.ReadInt32(); - h.CheckDecode(cbFloat == sizeof(Float)); - return new LdaTransform(h, ctx, input); - }); - } - - public string GetTopicSummary() - { - StringWriter writer = new StringWriter(); - VBuffer> slotNames = default; - for (int i = 0; i < _ldas.Length; i++) - { - GetSlotNames(i, ref slotNames); - _ldas[i].GetTopicSummaryWriter(slotNames)(writer); - writer.WriteLine(); - } - return writer.ToString(); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // int: sizeof(Float) - // - // ldaState[num infos]: The LDA parameters - - ctx.Writer.Write(sizeof(Float)); - SaveBase(ctx); - Host.Assert(_ldas.Length == Infos.Length); - VBuffer> slotNames = default; - for (int i = 0; i < _ldas.Length; i++) - { - GetSlotNames(i, ref slotNames); - _ldas[i].Save(ctx, _saveText, slotNames); - } - } - - private 
void GetSlotNames(int iinfo, ref VBuffer> dst) - { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - if (Source.Schema.HasSlotNames(Infos[iinfo].Source, Infos[iinfo].TypeSrc.ValueCount)) - Source.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, Infos[iinfo].Source, ref dst); - else - dst = default(VBuffer>); - } - - private static string TestType(ColumnType t) - { - // LDA consumes term frequency vectors, so I am assuming VBuffer is an appropriate input type. - // It must also be of known size for the sake of the LDA trainer initialization. - if (t.IsKnownSizeVector && t.ItemType is NumberType) - return null; - return "Expected vector of number type of known size."; - } - - private static int GetFrequency(double value) - { - int result = (int)value; - if (!(result == value && result >= 0)) - return -1; - return result; } - private void Train(IChannel ch, IDataView trainingData, LdaState[] states) + internal LdaSummary GetLdaDetails(int iinfo) { - Host.AssertValue(ch); - ch.AssertValue(trainingData); - ch.AssertValue(states); - ch.Assert(states.Length == Infos.Length); - - bool[] activeColumns = new bool[trainingData.Schema.ColumnCount]; - int[] numVocabs = new int[Infos.Length]; - - for (int i = 0; i < Infos.Length; i++) - { - activeColumns[Infos[i].Source] = true; - numVocabs[i] = 0; - } - - //the current lda needs the memory allocation before feedin data, so needs two sweeping of the data, - //one for the pre-calc memory, one for feedin data really - //another solution can be prepare these two value externally and put them in the beginning of the input file. 
- long[] corpusSize = new long[Infos.Length]; - int[] numDocArray = new int[Infos.Length]; - - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) - { - var getters = new ValueGetter>[Utils.Size(Infos)]; - for (int i = 0; i < Infos.Length; i++) - { - corpusSize[i] = 0; - numDocArray[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, Infos[i].Source); - } - VBuffer src = default(VBuffer); - long rowCount = 0; - - while (cursor.MoveNext()) - { - ++rowCount; - for (int i = 0; i < Infos.Length; i++) - { - int docSize = 0; - getters[i](ref src); + Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length); - // compute term, doc instance#. - var srcValues = src.GetValues(); - for (int termID = 0; termID < srcValues.Length; termID++) - { - int termFreq = GetFrequency(srcValues[termID]); - if (termFreq < 0) - { - // Ignore this row. - docSize = 0; - break; - } + var ldaState = _ldas[iinfo]; + var mapping = _columnMappings[iinfo]; - if (docSize >= _exes[i].NumMaxDocToken - termFreq) - break; //control the document length - - //if legal then add the term - docSize += termFreq; - } - - // Ignore empty doc - if (docSize == 0) - continue; - - numDocArray[i]++; - corpusSize[i] += docSize * 2 + 1; // in the beggining of each doc, there is a cursor variable - - // increase numVocab if needed. 
- if (numVocabs[i] < src.Length) - numVocabs[i] = src.Length; - } - } - - for (int i = 0; i < Infos.Length; ++i) - { - if (numDocArray[i] != rowCount) - { - ch.Assert(numDocArray[i] < rowCount); - ch.Warning($"Column '{Infos[i].Name}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); - } - } - } - - // Initialize all LDA states - for (int i = 0; i < Infos.Length; i++) - { - var state = new LdaState(Host, _exes[i], numVocabs[i]); - if (numDocArray[i] == 0 || corpusSize[i] == 0) - throw ch.Except("The specified documents are all empty in column '{0}'.", Infos[i].Name); - - state.AllocateDataMemory(numDocArray[i], corpusSize[i]); - states[i] = state; - } - - using (var cursor = trainingData.GetRowCursor(col => activeColumns[col])) - { - int[] docSizeCheck = new int[Infos.Length]; - // This could be optimized so that if multiple trainers consume the same column, it is - // fed into the train method once. - var getters = new ValueGetter>[Utils.Size(Infos)]; - for (int i = 0; i < Infos.Length; i++) - { - docSizeCheck[i] = 0; - getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, Infos[i].Source); - } - - VBuffer src = default(VBuffer); - - while (cursor.MoveNext()) - { - for (int i = 0; i < Infos.Length; i++) - { - getters[i](ref src); - docSizeCheck[i] += states[i].FeedTrain(Host, in src); - } - } - for (int i = 0; i < Infos.Length; i++) - { - Host.Assert(corpusSize[i] == docSizeCheck[i]); - states[i].CompleteTrain(); - } - } + return ldaState.GetLdaSummary(mapping); } private sealed class LdaState : IDisposable { - public readonly ColInfoEx InfoEx; + internal readonly ColumnInfo InfoEx; private readonly int _numVocab; private readonly object _preparationSyncRoot; private readonly object _testSyncRoot; @@ -602,7 +366,7 @@ private LdaState() _testSyncRoot = new object(); } - public LdaState(IExceptionContext ectx, ColInfoEx ex, int numVocab) + internal LdaState(IExceptionContext 
ectx, ColumnInfo ex, int numVocab) : this() { Contracts.AssertValue(ectx); @@ -626,7 +390,7 @@ public LdaState(IExceptionContext ectx, ColInfoEx ex, int numVocab) InfoEx.NumMaxDocToken); } - public LdaState(IExceptionContext ectx, ModelLoadContext ctx) + internal LdaState(IExceptionContext ectx, ModelLoadContext ctx) : this() { ectx.AssertValue(ctx); @@ -639,7 +403,7 @@ public LdaState(IExceptionContext ectx, ModelLoadContext ctx) // (serializing term by term, for one term) // int: term_id, int: topic_num, KeyValuePair[]: termTopicVector - InfoEx = new ColInfoEx(ectx, ctx); + InfoEx = new ColumnInfo(ectx, ctx); _numVocab = ctx.Reader.ReadInt32(); ectx.CheckDecode(_numVocab > 0); @@ -688,54 +452,52 @@ public LdaState(IExceptionContext ectx, ModelLoadContext ctx) //do the preparation if (!_predictionPreparationDone) { - _ldaTrainer.InitializeBeforeTest(); - _predictionPreparationDone = true; + lock (_preparationSyncRoot) + { + _ldaTrainer.InitializeBeforeTest(); + _predictionPreparationDone = true; + } } } - public Action GetTopicSummaryWriter(VBuffer> mapping) + internal LdaSummary GetLdaSummary(VBuffer> mapping) { - Action writeAction; - if (mapping.Length == 0) { - writeAction = - writer => + var itemScoresPerTopicBuilder = ImmutableArray.CreateBuilder>(); + for (int i = 0; i < _ldaTrainer.NumTopic; i++) + { + var scores = _ldaTrainer.GetTopicSummary(i); + var itemScores = new List<(int, float)>(); + foreach (KeyValuePair p in scores) { - for (int i = 0; i < _ldaTrainer.NumTopic; i++) - { - KeyValuePair[] topicSummaryVector = _ldaTrainer.GetTopicSummary(i); - writer.Write("{0}\t{1}\t", i, topicSummaryVector.Length); - foreach (KeyValuePair p in topicSummaryVector) - writer.Write("{0}:{1}\t", p.Key, p.Value); - writer.WriteLine(); - } - }; + itemScores.Add((p.Key, p.Value)); + } + + itemScoresPerTopicBuilder.Add(itemScores); + } + return new LdaSummary(itemScoresPerTopicBuilder.ToImmutable()); } else { - writeAction = - writer => + ReadOnlyMemory slotName = 
default; + var wordScoresPerTopicBuilder = ImmutableArray.CreateBuilder>(); + for (int i = 0; i < _ldaTrainer.NumTopic; i++) + { + var scores = _ldaTrainer.GetTopicSummary(i); + var wordScores = new List<(int, string, float)>(); + foreach (KeyValuePair p in scores) { - ReadOnlyMemory slotName = default; - for (int i = 0; i < _ldaTrainer.NumTopic; i++) - { - KeyValuePair[] topicSummaryVector = _ldaTrainer.GetTopicSummary(i); - writer.Write("{0}\t{1}\t", i, topicSummaryVector.Length); - foreach (KeyValuePair p in topicSummaryVector) - { - mapping.GetItemOrDefault(p.Key, ref slotName); - writer.Write("{0}[{1}]:{2}\t", p.Key, slotName, p.Value); - } - writer.WriteLine(); - } - }; + mapping.GetItemOrDefault(p.Key, ref slotName); + wordScores.Add((p.Key, slotName.ToString(), p.Value)); + } + wordScoresPerTopicBuilder.Add(wordScores); + } + return new LdaSummary(wordScoresPerTopicBuilder.ToImmutable()); } - - return writeAction; } - public void Save(ModelSaveContext ctx, bool saveText, VBuffer> mapping) + public void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); long memBlockSize = 0; @@ -770,12 +532,6 @@ public void Save(ModelSaveContext ctx, bool saveText, VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) + public void Output(in VBuffer src, ref VBuffer dst, int numBurninIter, bool reset) { // Prediction for a single document. // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe. @@ -871,8 +627,9 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin // It currently produces a vbuffer of all NA values. // REVIEW: Need a utility method to do this... 
editor = VBufferEditor.Create(ref dst, len); + for (int k = 0; k < len; k++) - editor.Values[k] = Float.NaN; + editor.Values[k] = float.NaN; dst = editor.Commit(); return; } @@ -899,7 +656,7 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin for (int i = 0; i < count; i++) { int index = retTopics[i].Key; - Float value = retTopics[i].Value; + float value = retTopics[i].Value; Contracts.Assert(value >= 0); Contracts.Assert(0 <= index && index < len); if (count < len) @@ -917,8 +674,9 @@ public void Output(in VBuffer src, ref VBuffer dst, int numBurnin if (normalizer > 0) { for (int i = 0; i < count; i++) - editor.Values[i] = (Float)(editor.Values[i] / normalizer); + editor.Values[i] = (float)(editor.Values[i] / normalizer); } + dst = editor.Commit(); } @@ -928,46 +686,481 @@ public void Dispose() } } - private ColumnType[] InitColumnTypes(int numTopics) + private sealed class Mapper : OneToOneMapperBase + { + private readonly LatentDirichletAllocationTransformer _parent; + private readonly int[] _srcCols; + + public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _srcCols = new int[_parent.ColumnPairs.Length]; + + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + + var srcCol = inputSchema[_srcCols[i]]; + if (!srcCol.Type.IsKnownSizeVector || !(srcCol.Type.ItemType is NumberType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input, "a fixed vector of floats", srcCol.Type.ToString()); + } + } + + protected override Schema.Column[] GetOutputColumnsCore() + { + var result = new Schema.Column[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + var info = 
_parent._columns[i]; + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, new VectorType(NumberType.Float, info.NumTopic), null); + } + return result; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, Func activeOutput, out Action disposer) + { + Contracts.AssertValue(input); + Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + disposer = null; + + return GetTopic(input, iinfo); + } + + private ValueGetter> GetTopic(IRow input, int iinfo) + { + var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, _srcCols[iinfo]); + var src = default(VBuffer); + var lda = _parent._ldas[iinfo]; + int numBurninIter = lda.InfoEx.NumBurninIter; + bool reset = lda.InfoEx.ResetRandomGenerator; + return + (ref VBuffer dst) => + { + // REVIEW: This will work, but there are opportunities for caching + // based on input.Counter that are probably worthwhile given how long inference takes. + getSrc(ref src); + lda.Output(in src, ref dst, numBurninIter, reset); + }; + } + } + + internal const string LoaderSignature = "LdaTransform"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "LIGHTLDA", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LatentDirichletAllocationTransformer).Assembly.FullName); + } + + private readonly ColumnInfo[] _columns; + private readonly LdaState[] _ldas; + private readonly List>> _columnMappings; + + private const string RegistrationName = "LightLda"; + private const string WordTopicModelFilename = "word_topic_summary.txt"; + internal const string Summary = "The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation."; + internal const string UserName = "Latent Dirichlet Allocation Transform"; + internal const string ShortName = "LightLda"; + + private static (string input, string output)[] 
GetColumnPairs(ColumnInfo[] columns) { - Host.Assert(Utils.Size(Infos) > 0); - var types = new ColumnType[Infos.Length]; - for (int c = 0; c < Infos.Length; c++) - types[c] = new VectorType(NumberType.Float, numTopics); - return types; + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } - protected override ColumnType GetColumnTypeCore(int iinfo) + /// + /// Initializes a new object. + /// + /// Host Environment. + /// An array of LdaState objects, where ldas[i] is learnt from the i-th element of . + /// A list of mappings, where columnMapping[i] is a map of slot names for the i-th element of . + /// Describes the parameters of the LDA process for each column pair. + private LatentDirichletAllocationTransformer(IHostEnvironment env, + LdaState[] ldas, + List>> columnMappings, + params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LatentDirichletAllocationTransformer)), GetColumnPairs(columns)) { - Host.Assert(0 <= iinfo & iinfo < Utils.Size(_types)); - return _types[iinfo]; + Host.AssertNonEmpty(ColumnPairs); + _ldas = ldas; + _columnMappings = columnMappings; + _columns = columns; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private LatentDirichletAllocationTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + Host.AssertValue(ctx); + + // *** Binary format *** + // + // + // ldaState[num infos]: The LDA parameters - return GetTopic(input, iinfo); + // Note: columnsLength would be just one in most cases. 
+ var columnsLength = ColumnPairs.Length; + _columns = new ColumnInfo[columnsLength]; + _ldas = new LdaState[columnsLength]; + for (int i = 0; i < _ldas.Length; i++) + { + _ldas[i] = new LdaState(Host, ctx); + _columns[i] = _ldas[i].InfoEx; + } + } + + internal static LatentDirichletAllocationTransformer TrainLdaTransformer(IHostEnvironment env, IDataView inputData, params ColumnInfo[] columns) + { + var ldas = new LdaState[columns.Length]; + + List>> columnMappings; + using (var ch = env.Start("Train")) + { + columnMappings = Train(env, ch, inputData, ldas, columns); + } + + return new LatentDirichletAllocationTransformer(env, ldas, columnMappings, columns); + } + + private void Dispose(bool disposing) + { + if (_ldas != null) + { + foreach (var state in _ldas) + state?.Dispose(); + } + if (disposing) + GC.SuppressFinalize(this); + } + + public void Dispose() + { + Dispose(true); + } + + ~LatentDirichletAllocationTransformer() + { + Dispose(false); + } + + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(Schema.Create(inputSchema)); + + // Factory method for SignatureDataTransform. 
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + env.CheckValue(args.Column, nameof(args.Column)); + + var cols = args.Column.Select(colPair => new ColumnInfo(colPair, args)).ToArray(); + return TrainLdaTransformer(env, input, cols).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel + private static LatentDirichletAllocationTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + var h = env.Register(RegistrationName); + + h.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return h.Apply( + "Loading Model", + ch => + { + // *** Binary Format *** + // int: sizeof(float) + // + int cbFloat = ctx.Reader.ReadInt32(); + h.CheckDecode(cbFloat == sizeof(float)); + return new LatentDirichletAllocationTransformer(h, ctx); + }); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // int: sizeof(float) + // + // ldaState[num infos]: The LDA parameters + + ctx.Writer.Write(sizeof(float)); + SaveColumns(ctx); + for (int i = 0; i < _ldas.Length; i++) + { + _ldas[i].Save(ctx); + } + } + + private static int GetFrequency(double value) + { + int result = (int)value; + if (!(result == value && result >= 0)) + return -1; + return result; } - private ValueGetter> GetTopic(IRow input, int iinfo) + private static List>> Train(IHostEnvironment env, IChannel ch, IDataView inputData, LdaState[] states, params ColumnInfo[] columns) { - var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.R8, input, Infos[iinfo].Source); - var src = default(VBuffer); - var lda = _ldas[iinfo]; - int numBurninIter = lda.InfoEx.NumBurninIter; - bool reset = lda.InfoEx.ResetRandomGenerator; - return - (ref 
VBuffer dst) => + env.AssertValue(ch); + ch.AssertValue(inputData); + ch.AssertValue(states); + ch.Assert(states.Length == columns.Length); + + bool[] activeColumns = new bool[inputData.Schema.ColumnCount]; + int[] numVocabs = new int[columns.Length]; + int[] srcCols = new int[columns.Length]; + + var columnMappings = new List>>(); + + var inputSchema = inputData.Schema; + for (int i = 0; i < columns.Length; i++) + { + if (!inputData.Schema.TryGetColumnIndex(columns[i].Input, out int srcCol)) + throw env.ExceptSchemaMismatch(nameof(inputData), "input", columns[i].Input); + + var srcColType = inputSchema.GetColumnType(srcCol); + if (!srcColType.IsKnownSizeVector || !(srcColType.ItemType is NumberType)) + throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input, "a fixed vector of floats", srcColType.ToString()); + + srcCols[i] = srcCol; + activeColumns[srcCol] = true; + numVocabs[i] = 0; + + VBuffer> dst = default; + if (inputSchema.HasSlotNames(srcCol, srcColType.ValueCount)) + inputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, srcCol, ref dst); + else + dst = default(VBuffer>); + columnMappings.Add(dst); + } + + //the current LDA needs the memory allocation before feeding data, so it needs two sweeps of the data, + //one for the pre-calc memory, one for really feeding the data + //another solution can be to prepare these two values externally and put them in the beginning of the input file. + long[] corpusSize = new long[columns.Length]; + int[] numDocArray = new int[columns.Length]; + + using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) + { + var getters = new ValueGetter>[columns.Length]; + for (int i = 0; i < columns.Length; i++) { - // REVIEW: This will work, but there are opportunities for caching - // based on input.Counter that are probably worthwhile given how long inference takes. 
- getSrc(ref src); - lda.Output(in src, ref dst, numBurninIter, reset); - }; + corpusSize[i] = 0; + numDocArray[i] = 0; + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); + } + VBuffer src = default(VBuffer); + long rowCount = 0; + while (cursor.MoveNext()) + { + ++rowCount; + for (int i = 0; i < columns.Length; i++) + { + int docSize = 0; + getters[i](ref src); + + // compute term, doc instance#. + var srcValues = src.GetValues(); + for (int termID = 0; termID < srcValues.Length; termID++) + { + int termFreq = GetFrequency(srcValues[termID]); + if (termFreq < 0) + { + // Ignore this row. + docSize = 0; + break; + } + + if (docSize >= columns[i].NumMaxDocToken - termFreq) + break; //control the document length + + //if legal then add the term + docSize += termFreq; + } + + // Ignore empty doc + if (docSize == 0) + continue; + + numDocArray[i]++; + corpusSize[i] += docSize * 2 + 1; // in the beginning of each doc, there is a cursor variable + + // increase numVocab if needed. 
+ if (numVocabs[i] < src.Length) + numVocabs[i] = src.Length; + } + } + + // No data to train on, just return + if (rowCount == 0) + return columnMappings; + + for (int i = 0; i < columns.Length; ++i) + { + if (numDocArray[i] != rowCount) + { + ch.Assert(numDocArray[i] < rowCount); + ch.Warning($"Column '{columns[i].Input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + } + } + } + + // Initialize all LDA states + for (int i = 0; i < columns.Length; i++) + { + var state = new LdaState(env, columns[i], numVocabs[i]); + + if (numDocArray[i] == 0 || corpusSize[i] == 0) + throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].Input); + + state.AllocateDataMemory(numDocArray[i], corpusSize[i]); + states[i] = state; + } + + using (var cursor = inputData.GetRowCursor(col => activeColumns[col])) + { + int[] docSizeCheck = new int[columns.Length]; + // This could be optimized so that if multiple trainers consume the same column, it is + // fed into the train method once. 
+ var getters = new ValueGetter>[columns.Length]; + for (int i = 0; i < columns.Length; i++) + { + docSizeCheck[i] = 0; + getters[i] = RowCursorUtils.GetVecGetterAs(NumberType.R8, cursor, srcCols[i]); + } + + VBuffer src = default(VBuffer); + + while (cursor.MoveNext()) + { + for (int i = 0; i < columns.Length; i++) + { + getters[i](ref src); + docSizeCheck[i] += states[i].FeedTrain(env, in src); + } + } + + for (int i = 0; i < columns.Length; i++) + { + env.Assert(corpusSize[i] == docSizeCheck[i]); + states[i].CompleteTrain(); + } + } + + return columnMappings; + } + + protected override IRowMapper MakeRowMapper(Schema schema) + { + return new Mapper(this, schema); + } + } + + /// + public sealed class LatentDirichletAllocationEstimator : IEstimator + { + internal static class Defaults + { + public const int NumTopic = 100; + public const float AlphaSum = 100; + public const float Beta = 0.01f; + public const int Mhstep = 4; + public const int NumIterations = 200; + public const int LikelihoodInterval = 5; + public const int NumThreads = 0; + public const int NumMaxDocToken = 512; + public const int NumSummaryTermPerTopic = 10; + public const int NumBurninIterations = 10; + public const bool ResetRandomGenerator = false; + } + + private readonly IHost _host; + private readonly ImmutableArray _columns; + + /// + /// The environment. + /// The column representing the document as a vector of floats. + /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// The number of topics. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc. 
+ /// The number of words to summarize the topic. + /// The number of burn-in iterations. + /// Reset the random number generator for each document. + public LatentDirichletAllocationEstimator(IHostEnvironment env, + string inputColumn, + string outputColumn = null, + int numTopic = Defaults.NumTopic, + float alphaSum = Defaults.AlphaSum, + float beta = Defaults.Beta, + int mhstep = Defaults.Mhstep, + int numIterations = Defaults.NumIterations, + int likelihoodInterval = Defaults.LikelihoodInterval, + int numThreads = Defaults.NumThreads, + int numMaxDocToken = Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, + int numBurninIterations = Defaults.NumBurninIterations, + bool resetRandomGenerator = Defaults.ResetRandomGenerator) + : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, + numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, + numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) + { } + + /// + /// The environment. + /// Describes the parameters of the LDA process for each column pair. + public LatentDirichletAllocationEstimator(IHostEnvironment env, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(LatentDirichletAllocationEstimator)); + _columns = columns.ToImmutableArray(); + } + + /// + /// Returns the schema that would be produced by the transformation. 
+ /// + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (col.ItemType.RawKind != DataKind.R4 || col.Kind == SchemaShape.Column.VectorKind.Scalar) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "a vector of floats", col.GetTypeString()); + + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + return new SchemaShape(result.Values); + } + + public LatentDirichletAllocationTransformer Fit(IDataView input) + { + return LatentDirichletAllocationTransformer.TrainLdaTransformer(_host, input, _columns.ToArray()); } } } diff --git a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs index 16d84c23e3..83b05cbbb7 100644 --- a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs @@ -591,56 +591,4 @@ public static Vector ToNgramsHash(this VarVector> input bool ordered = true, int invertHash = 0) => new OutPipelineColumn(input, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); } - - /// - /// Extensions for statically typed . 
- /// - public static class LdaEstimatorExtensions - { - private sealed class OutPipelineColumn : Vector - { - public readonly Vector Input; - - public OutPipelineColumn(Vector input, int numTopic, Action advancedSettings) - : base(new Reconciler(numTopic, advancedSettings), input) - { - Input = input; - } - } - - private sealed class Reconciler : EstimatorReconciler - { - private readonly int _numTopic; - private readonly Action _advancedSettings; - - public Reconciler(int numTopic, Action advancedSettings) - { - _numTopic = numTopic; - _advancedSettings = advancedSettings; - } - - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - Contracts.Assert(toOutput.Length == 1); - - var pairs = new List<(string input, string output)>(); - foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); - - return new LdaEstimator(env, pairs.ToArray(), _numTopic, _advancedSettings); - } - } - - /// - /// The column to apply to. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. 
- public static Vector ToLdaTopicVector(this Vector input, - int numTopic = 100, - Action advancedSettings = null) => new OutPipelineColumn(input, numTopic, advancedSettings); - } } diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 97fa3ddedb..e5fd8f8cce 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -452,50 +452,4 @@ public override TransformWrapper Fit(IDataView input) return new TransformWrapper(Host, new NgramHashingTransformer(Host, args, input)); } } - - /// - public sealed class LdaEstimator : TrainedWrapperEstimatorBase - { - private readonly LdaTransform.Arguments _args; - - /// - /// The environment. - /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. - public LdaEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, - int numTopic = 100, - Action advancedSettings = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, - numTopic, - advancedSettings) - { - } - - /// - /// The environment. - /// Pairs of columns to compute LDA. - /// The number of topics in the LDA. - /// A delegate to apply all the advanced arguments to the algorithm. 
- public LdaEstimator(IHostEnvironment env, - (string input, string output)[] columns, - int numTopic = 100, - Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LdaEstimator))) - { - _args = new LdaTransform.Arguments(); - _args.Column = columns.Select(x => new LdaTransform.Column { Source = x.input, Name = x.output }).ToArray(); - _args.NumTopic = numTopic; - - advancedSettings?.Invoke(_args); - } - - public override TransformWrapper Fit(IDataView input) - { - return new TransformWrapper(Host, new LdaTransform(Host, _args, input)); - } - } } \ No newline at end of file diff --git a/src/Microsoft.ML.Transforms/TextCatalog.cs b/src/Microsoft.ML.Transforms/TextCatalog.cs index a052fea859..c3b01bfe0b 100644 --- a/src/Microsoft.ML.Transforms/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/TextCatalog.cs @@ -230,5 +230,48 @@ public static NgramCountingEstimator ProduceNgrams(this TransformsCatalog.TextTr params NgramCountingTransformer.ColumnInfo[] columns) => new NgramCountingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); + /// + /// Uses LightLDA to transform a document (represented as a vector of floats) + /// into a vector of floats over a set of topics. + /// + /// The transform's catalog. + /// The column representing the document as a vector of floats. + /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// The number of topics. + /// Dirichlet prior on document-topic vectors. + /// Dirichlet prior on vocab-topic vectors. + /// Number of Metropolis Hasting step. + /// Number of iterations. + /// Compute log likelihood over local dataset on this iteration interval. + /// The number of training threads. Default value depends on number of logical processors. + /// The threshold of maximum count of tokens per doc. + /// The number of words to summarize the topic. 
+ /// The number of burn-in iterations. + /// Reset the random number generator for each document. + public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, + string inputColumn, + string outputColumn = null, + int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, + float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, + float beta = LatentDirichletAllocationEstimator.Defaults.Beta, + int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep, + int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations, + int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval, + int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads, + int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken, + int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, + int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, + bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) + => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, + numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); + + /// + /// Uses LightLDA to transform a document (represented as a vector of floats) + /// into a vector of floats over a set of topics. + /// + /// The transform's catalog. + /// Describes the parameters of LDA for each column pair. 
+ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index af6b0eef0d..655b617d72 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -104,7 +104,7 @@ Transforms.KeyToTextConverter KeyToValueTransform utilizes KeyValues metadata to Transforms.LabelColumnKeyBooleanConverter Transforms the label to either key or bool (if needed) to make it suitable for classification. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareClassificationLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+ClassificationLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelIndicator Label remapper used by OVA Microsoft.ML.Transforms.LabelIndicatorTransform LabelIndicator Microsoft.ML.Transforms.LabelIndicatorTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LabelToFloatConverter Transforms the label to float to make it suitable for regression. Microsoft.ML.Runtime.EntryPoints.FeatureCombiner PrepareRegressionLabel Microsoft.ML.Runtime.EntryPoints.FeatureCombiner+RegressionLabelInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput -Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LdaTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.LightLda The LDA transform implements LightLDA, a state-of-the-art implementation of Latent Dirichlet Allocation. 
Microsoft.ML.Transforms.Text.TextAnalytics LightLda Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LogMeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the logarithm of the data. Microsoft.ML.Runtime.Data.Normalize LogMeanVar Microsoft.ML.Transforms.Normalizers.NormalizeTransform+LogMeanVarArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.LpNormalizer Normalize vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf). Performs the following operation on a vector X: Y = (X - M) / D, where M is mean and D is either L2 norm, L1 norm or LInf norm. Microsoft.ML.Transforms.Projections.LpNormalization Normalize Microsoft.ML.Transforms.Projections.LpNormalizingTransformer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel. 
Microsoft.ML.Runtime.EntryPoints.ModelOperations CombineModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index bfdec78010..6b5b8eb81a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -20054,7 +20054,7 @@ { "Name": "NumTopic", "Type": "Int", - "Desc": "The number of topics in the LDA", + "Desc": "The number of topics", "Required": false, "SortOrder": 150.0, "IsNullable": true, @@ -20209,7 +20209,7 @@ { "Name": "NumTopic", "Type": "Int", - "Desc": "The number of topics in the LDA", + "Desc": "The number of topics", "Required": false, "SortOrder": 50.0, "IsNullable": false, @@ -20225,28 +20225,28 @@ } }, { - "Name": "NumMaxDocToken", + "Name": "NumThreads", "Type": "Int", - "Desc": "The threshold of maximum count of tokens per doc", + "Desc": "The number of training threads. Default value depends on number of logical processors.", "Aliases": [ - "maxNumToken" + "t" ], "Required": false, "SortOrder": 50.0, "IsNullable": false, - "Default": 512 + "Default": 0 }, { - "Name": "NumThreads", + "Name": "NumMaxDocToken", "Type": "Int", - "Desc": "The number of training threads. 
Default value depends on number of logical processors.", + "Desc": "The threshold of maximum count of tokens per doc", "Aliases": [ - "t" + "maxNumToken" ], "Required": false, "SortOrder": 50.0, - "IsNullable": true, - "Default": null + "IsNullable": false, + "Default": 512 }, { "Name": "AlphaSum", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index ca30ecb500..386f9f8a97 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2020,7 +2020,7 @@ public void EntryPointPcaTransform() } [Fact] - public void EntryPointLightLdaTransform() + public void EntryPointLightLdaTransformer() { string dataFile = DeleteOutputPath("SavePipe", "SavePipeTextLightLda-SampleText.txt"); File.WriteAllLines(dataFile, new[] { diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index ab7049a58a..4de9cfc660 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -22,6 +22,7 @@ using System.Text; using Xunit; using Xunit.Abstractions; +using static Microsoft.ML.Transforms.Text.LatentDirichletAllocationTransformer; namespace Microsoft.ML.StaticPipelineTesting { @@ -666,7 +667,7 @@ public void LpGcNormAndWhitening() Assert.True(type is VectorType vecType4 && vecType4.Size > 0 && vecType4.ItemType is NumberType); } - [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")] + [Fact] public void LdaTopicModel() { var env = new MLContext(0); @@ -677,21 +678,22 @@ public void LdaTopicModel() var dataSource = new MultiFileSource(dataPath); var data = reader.Read(dataSource); + // This will be populated once we call fit. 
+ LdaSummary ldaSummary; + var est = data.MakeNewEstimator() .Append(r => ( r.label, - topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 10, advancedSettings: s => - { - s.AlphaSum = 10; - }))); + topics: r.text.ToBagofWords().ToLdaTopicVector(numTopic: 3, numSummaryTermPerTopic:5, alphaSum: 10, onFit: m => ldaSummary = m.LdaTopicSummary))); - var tdata = est.Fit(data).Transform(data); - var schema = tdata.AsDynamic.Schema; + var transformer = est.Fit(data); + var tdata = transformer.Transform(data); + var schema = tdata.AsDynamic.Schema; Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol)); var type = schema.GetColumnType(topicsCol); Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); -} + } [Fact(Skip = "FeatureSeclection transform cannot be trained on empty data, schema propagation fails")] public void FeatureSelection() diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index a9b6d451e2..d2b575a98b 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -913,23 +913,13 @@ public void TestLDATransform() }; builder.AddColumn("F1V", NumberType.Float, data); - var srcView = builder.GetDataView(); - LdaTransform.Column col = new LdaTransform.Column(); - col.Source = "F1V"; - col.NumTopic = 20; - col.NumTopic = 3; - col.NumSummaryTermPerTopic = 3; - col.AlphaSum = 3; - col.NumThreads = 1; - col.ResetRandomGenerator = true; - LdaTransform.Arguments args = new LdaTransform.Arguments(); - args.Column = new LdaTransform.Column[] { col }; - - LdaTransform ldaTransform = new LdaTransform(Env, args, srcView); + var est = new LatentDirichletAllocationEstimator(Env, "F1V", numTopic: 3, numSummaryTermPerTopic: 3, alphaSum: 3, numThreads: 1, resetRandomGenerator: true); + var ldaTransformer = est.Fit(srcView); + var transformedData = 
ldaTransformer.Transform(srcView); - using (var cursor = ldaTransform.GetRowCursor(c => true)) + using (var cursor = transformedData.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); VBuffer resultFirstRow = new VBuffer(); @@ -960,7 +950,7 @@ public void TestLDATransform() } [Fact] - public void TestLdaTransformEmptyDocumentException() + public void TestLdaTransformerEmptyDocumentException() { var builder = new ArrayDataViewBuilder(Env); var data = new[] @@ -973,18 +963,18 @@ public void TestLdaTransformEmptyDocumentException() builder.AddColumn("Zeros", NumberType.Float, data); var srcView = builder.GetDataView(); - var col = new LdaTransform.Column() + var col = new LatentDirichletAllocationTransformer.Column() { - Source = "Zeros" + Source = "Zeros", }; - var args = new LdaTransform.Arguments() + var args = new LatentDirichletAllocationTransformer.Arguments() { Column = new[] { col } }; try { - var lda = new LdaTransform(Env, args, srcView); + var lda = new LatentDirichletAllocationEstimator(Env, "Zeros").Fit(srcView).Transform(srcView); } catch (InvalidOperationException ex) { diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 5eeadfaf07..d3d66d5751 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -6,10 +6,11 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms; -using Microsoft.ML.Transforms.Text; using Microsoft.ML.Transforms.Categorical; using Microsoft.ML.Transforms.Conversions; +using Microsoft.ML.Transforms.Text; using System.IO; using Xunit; using Xunit.Abstractions; @@ -235,7 +236,7 @@ public void NgramWorkout() Done(); } - [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")] + [Fact] public void 
LdaWorkout() { IHostEnvironment env = new MLContext(seed: 42, conc: 1); @@ -251,21 +252,21 @@ public void LdaWorkout() .Read(sentimentDataPath); var est = new WordBagEstimator(env, "text", "bag_of_words"). - Append(new LdaEstimator(env, "bag_of_words", "topics", 10, advancedSettings: s => - { - s.NumIterations = 10; - s.ResetRandomGenerator = true; - })); + Append(new LatentDirichletAllocationEstimator(env, "bag_of_words", "topics", 10, numIterations: 10, + resetRandomGenerator: true)); // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 + // In this test it manifests because of the WordBagEstimator in the estimator chain // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("Text", "ldatopics.tsv"); using (var ch = env.Start("save")) { var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false, Dense = true }); - IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4); + var transformer = est.Fit(data.AsDynamic); + var transformedData = transformer.Transform(data.AsDynamic); + IDataView savedData = TakeFilter.Create(env, transformedData, 4); savedData = ColumnSelectingTransformer.CreateKeep(env, savedData, new[] { "topics" }); using (var fs = File.Create(outputPath)) @@ -281,5 +282,30 @@ public void LdaWorkout() // CheckEquality("Text", "ldatopics.tsv"); Done(); } + + [Fact] + public void LdaWorkoutEstimatorCore() + { + var ml = new MLContext(); + + var builder = new ArrayDataViewBuilder(Env); + var data = new[] + { + new[] { (float)1.0, (float)0.0, (float)0.0 }, + new[] { (float)0.0, (float)1.0, (float)0.0 }, + new[] { (float)0.0, (float)0.0, (float)1.0 }, + }; + builder.AddColumn("F1V", NumberType.Float, data); + var srcView = builder.GetDataView(); + + var est = ml.Transforms.Text.LatentDirichletAllocation("F1V"); + TestEstimatorCore(est, srcView); + } + + [Fact] + 
public void TestLdaCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0-10} xf=lda{col=B:A} in=f:\2.txt" }), (int)0); + } } }