Skip to content

Commit

Permalink
Convert LdaTransform to IEstimator/ITransformer API (#1410)
Browse files Browse the repository at this point in the history
  • Loading branch information
abgoswam authored and shauheen committed Nov 20, 2018
1 parent dafa30c commit 1a9e7aa
Show file tree
Hide file tree
Showing 13 changed files with 931 additions and 597 deletions.
24 changes: 12 additions & 12 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13997,10 +13997,10 @@ public LabelToFloatConverterPipelineStep(Output output)
namespace Legacy.Transforms
{

public sealed partial class LdaTransformColumn : OneToOneColumn<LdaTransformColumn>, IOneToOneColumn
public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn<LatentDirichletAllocationTransformerColumn>, IOneToOneColumn
{
/// <summary>
/// The number of topics in the LDA
/// The number of topics
/// </summary>
public int? NumTopic { get; set; }

Expand Down Expand Up @@ -14099,26 +14099,26 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo

public void AddColumn(string inputColumn)
{
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(inputColumn));
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(inputColumn));
Column = list.ToArray();
}

public void AddColumn(string outputColumn, string inputColumn)
{
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(outputColumn, inputColumn));
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(outputColumn, inputColumn));
Column = list.ToArray();
}


/// <summary>
/// New column definition(s) (optional form: name:srcs)
/// </summary>
public LdaTransformColumn[] Column { get; set; }
public LatentDirichletAllocationTransformerColumn[] Column { get; set; }

/// <summary>
/// The number of topics in the LDA
/// The number of topics
/// </summary>
[TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})]
public int NumTopic { get; set; } = 100;
Expand Down Expand Up @@ -14153,14 +14153,14 @@ public void AddColumn(string outputColumn, string inputColumn)
public int LikelihoodInterval { get; set; } = 5;

/// <summary>
/// The threshold of maximum count of tokens per doc
/// The number of training threads. Default value depends on number of logical processors.
/// </summary>
public int NumMaxDocToken { get; set; } = 512;
public int NumThreads { get; set; }

/// <summary>
/// The number of training threads. Default value depends on number of logical processors.
/// The threshold of maximum count of tokens per doc
/// </summary>
public int? NumThreads { get; set; }
public int NumMaxDocToken { get; set; } = 512;

/// <summary>
/// The number of words to summarize the topic
Expand Down
14 changes: 9 additions & 5 deletions src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Transforms.Categorical;
using Microsoft.ML.Transforms.Text;
using System.Linq;

[assembly: LoadableClass(typeof(void), typeof(TextAnalytics), null, typeof(SignatureEntryPointModule), "TextAnalytics")]

Expand Down Expand Up @@ -118,18 +119,21 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T
}

[TlcModule.EntryPoint(Name = "Transforms.LightLda",
Desc = LdaTransform.Summary,
UserName = LdaTransform.UserName,
ShortName = LdaTransform.ShortName,
Desc = LatentDirichletAllocationTransformer.Summary,
UserName = LatentDirichletAllocationTransformer.UserName,
ShortName = LatentDirichletAllocationTransformer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/member[@name=""LightLDA""]/*' />",
@"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/example[@name=""LightLDA""]/*' />" })]
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));

var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
var view = new LdaTransform(h, input, input.Data);
var cols = input.Column.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray();
var est = new LatentDirichletAllocationEstimator(h, cols);
var view = est.Fit(input.Data).Transform(input.Data);

return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
Expand Down
174 changes: 174 additions & 0 deletions src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.StaticPipe.Runtime;
using Microsoft.ML.Transforms.Text;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.StaticPipe
{
/// <summary>
/// Information on the result of fitting a LDA transform.
/// </summary>
public sealed class LdaFitResult
{
/// <summary>
/// For user defined delegates that accept instances of the containing type.
/// </summary>
/// <param name="result"></param>
public delegate void OnFit(LdaFitResult result);

public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
{
LdaTopicSummary = ldaTopicSummary;
}
}

public static class LdaStaticExtensions
{
private struct Config
{
public readonly int NumTopic;
public readonly Single AlphaSum;
public readonly Single Beta;
public readonly int MHStep;
public readonly int NumIter;
public readonly int LikelihoodInterval;
public readonly int NumThread;
public readonly int NumMaxDocToken;
public readonly int NumSummaryTermPerTopic;
public readonly int NumBurninIter;
public readonly bool ResetRandomGenerator;

public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval,
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator,
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
{
NumTopic = numTopic;
AlphaSum = alphaSum;
Beta = beta;
MHStep = mhStep;
NumIter = numIter;
LikelihoodInterval = likelihoodInterval;
NumThread = numThread;
NumMaxDocToken = numMaxDocToken;
NumSummaryTermPerTopic = numSummaryTermPerTopic;
NumBurninIter = numBurninIter;
ResetRandomGenerator = resetRandomGenerator;

OnFit = onFit;
}
}

private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
{
if (onFit == null)
return null;

return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary));
}

private interface ILdaCol
{
PipelineColumn Input { get; }
Config Config { get; }
}

private sealed class ImplVector : Vector<float>, ILdaCol
{
public PipelineColumn Input { get; }
public Config Config { get; }
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
{
Input = input;
Config = config;
}
}

private sealed class Rec : EstimatorReconciler
{
public static readonly Rec Inst = new Rec();

public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
PipelineColumn[] toOutput,
IReadOnlyDictionary<PipelineColumn, string> inputNames,
IReadOnlyDictionary<PipelineColumn, string> outputNames,
IReadOnlyCollection<string> usedNames)
{
var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length];
Action<LatentDirichletAllocationTransformer> onFit = null;
for (int i = 0; i < toOutput.Length; ++i)
{
var tcol = (ILdaCol)toOutput[i];

infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],
tcol.Config.NumTopic,
tcol.Config.AlphaSum,
tcol.Config.Beta,
tcol.Config.MHStep,
tcol.Config.NumIter,
tcol.Config.LikelihoodInterval,
tcol.Config.NumThread,
tcol.Config.NumMaxDocToken,
tcol.Config.NumSummaryTermPerTopic,
tcol.Config.NumBurninIter,
tcol.Config.ResetRandomGenerator);

if (tcol.Config.OnFit != null)
{
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call.
onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii));
}
}

var est = new LatentDirichletAllocationEstimator(env, infos);
if (onFit == null)
return est;

return est.WithOnFitDelegate(onFit);
}
}

/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
/// <param name="input">A vector of floats representing the document.</param>
/// <param name="numTopic">The number of topics.</param>
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
/// <param name="mhstep">Number of Metropolis Hasting step.</param>
/// <param name="numIterations">Number of iterations.</param>
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
public static Vector<float> ToLdaTopicVector(this Vector<float> input,
int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
LdaFitResult.OnFit onFit = null)
{
Contracts.CheckValue(input, nameof(input));
return new ImplVector(input,
new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic,
numBurninIterations, resetRandomGenerator, Wrap(onFit)));
}
}
}
Loading

0 comments on commit 1a9e7aa

Please sign in to comment.