Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert LdaTransform to IEstimator/ITransformer API #1410

Merged
merged 38 commits into from
Nov 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
acd8964
preparing to convert LDATransform to the IEstimator/ITransformer para…
abgoswam Oct 24, 2018
cd2f20c
Added ColumnInfo and TransformInfo
abgoswam Oct 24, 2018
c289ed1
existing tests pass after refactor
abgoswam Oct 25, 2018
8dfb527
fix re-arragement of arguments in entrypoints
abgoswam Oct 25, 2018
bcb3b0d
refactor + cleanup of TopicSummary
abgoswam Oct 25, 2018
1d39408
enabled OnFit to return LdaState
abgoswam Oct 27, 2018
9ff60f7
fixed merge conflicts
abgoswam Nov 1, 2018
ad36d2f
fix build issues after merge; fix review comments
abgoswam Nov 1, 2018
b0e0375
taking care of review comments - 1
abgoswam Nov 10, 2018
b0422e4
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam Nov 10, 2018
7bc6e2b
merge with master; re-enable LDA tests before taking care of addition…
abgoswam Nov 10, 2018
e42c5e4
review comments - 2. rename LdaTransform to LdaTransformer
abgoswam Nov 10, 2018
c099d4a
review comments - 3. throw ExceptSchemaMismatch; default values; ToIm…
abgoswam Nov 11, 2018
a1d14ed
review comments - 4. output column; expression body definition
abgoswam Nov 12, 2018
3f39a04
review comments - 4; added a basic test that exercises TestEstimatorC…
abgoswam Nov 12, 2018
57cd1c5
review comments - 5; rename _exes to _columns; preparing changes for …
abgoswam Nov 12, 2018
d4a4283
review comments - 6; make training a private static method.
abgoswam Nov 12, 2018
c7fb50a
review comments - 7; make Create() method private
abgoswam Nov 12, 2018
d7660ca
review comments - 8; added internal constructor for ColumnInfo()
abgoswam Nov 12, 2018
e0d501b
review comments - 9; fixed types for Single, Float
abgoswam Nov 12, 2018
c91afbb
review comments - 10; made LdaState internal, expose LDA summary info…
abgoswam Nov 13, 2018
4238fa1
review comments - 11; added a command line unit test
abgoswam Nov 13, 2018
34bb2e9
review comments - 12; no-op when there is no data (maml command line …
abgoswam Nov 13, 2018
8b70ab1
review comments - 13; added mlcontext extension for LDA
abgoswam Nov 14, 2018
b3c1284
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam Nov 14, 2018
b6e4028
review comments - 14; code refactor; avoid using abbreviation(Lda); r…
abgoswam Nov 15, 2018
5397de5
review comments - 15; schema changes
abgoswam Nov 15, 2018
edd60af
review comments - 15; provide better user-facing description. renamed…
abgoswam Nov 15, 2018
b073038
review comments - 16; fix build break
abgoswam Nov 15, 2018
49da3ee
review comments - 16; include words in LdaSummary (this also resolves…
abgoswam Nov 16, 2018
0724290
fixing merge conflicts
abgoswam Nov 16, 2018
5073baa
fixing merge conflicts
abgoswam Nov 17, 2018
d1481f8
fix build break because of manifest changes
abgoswam Nov 17, 2018
65125d4
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam Nov 20, 2018
b869d7f
updated to latest interface changes
abgoswam Nov 20, 2018
62955a8
review comments - 17; named tuple, namespace change, fix CheckParams
abgoswam Nov 20, 2018
40333a7
review comments - 18; ImmutableArray
abgoswam Nov 20, 2018
850856b
review comments - 19; update summary text; remove dup code in constru…
abgoswam Nov 20, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13997,10 +13997,10 @@ public LabelToFloatConverterPipelineStep(Output output)
namespace Legacy.Transforms
{

public sealed partial class LdaTransformColumn : OneToOneColumn<LdaTransformColumn>, IOneToOneColumn
public sealed partial class LatentDirichletAllocationTransformerColumn : OneToOneColumn<LatentDirichletAllocationTransformerColumn>, IOneToOneColumn
{
/// <summary>
/// The number of topics in the LDA
/// The number of topics
/// </summary>
public int? NumTopic { get; set; }

Expand Down Expand Up @@ -14099,26 +14099,26 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo

public void AddColumn(string inputColumn)
{
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(inputColumn));
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(inputColumn));
Column = list.ToArray();
}

public void AddColumn(string outputColumn, string inputColumn)
{
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>() : new List<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LdaTransformColumn>.Create(outputColumn, inputColumn));
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>(Column);
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.LatentDirichletAllocationTransformerColumn>.Create(outputColumn, inputColumn));
Column = list.ToArray();
}


/// <summary>
/// New column definition(s) (optional form: name:srcs)
/// </summary>
public LdaTransformColumn[] Column { get; set; }
public LatentDirichletAllocationTransformerColumn[] Column { get; set; }

/// <summary>
/// The number of topics in the LDA
/// The number of topics
/// </summary>
[TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})]
public int NumTopic { get; set; } = 100;
Expand Down Expand Up @@ -14153,14 +14153,14 @@ public void AddColumn(string outputColumn, string inputColumn)
public int LikelihoodInterval { get; set; } = 5;

/// <summary>
/// The threshold of maximum count of tokens per doc
/// The number of training threads. Default value depends on number of logical processors.
/// </summary>
public int NumMaxDocToken { get; set; } = 512;
public int NumThreads { get; set; }

/// <summary>
/// The number of training threads. Default value depends on number of logical processors.
/// The threshold of maximum count of tokens per doc
/// </summary>
public int? NumThreads { get; set; }
public int NumMaxDocToken { get; set; } = 512;

/// <summary>
/// The number of words to summarize the topic
Expand Down
14 changes: 9 additions & 5 deletions src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Transforms.Categorical;
using Microsoft.ML.Transforms.Text;
using System.Linq;

[assembly: LoadableClass(typeof(void), typeof(TextAnalytics), null, typeof(SignatureEntryPointModule), "TextAnalytics")]

Expand Down Expand Up @@ -118,18 +119,21 @@ public static CommonOutputs.TransformOutput CharTokenize(IHostEnvironment env, T
}

[TlcModule.EntryPoint(Name = "Transforms.LightLda",
Desc = LdaTransform.Summary,
UserName = LdaTransform.UserName,
ShortName = LdaTransform.ShortName,
Desc = LatentDirichletAllocationTransformer.Summary,
UserName = LatentDirichletAllocationTransformer.UserName,
ShortName = LatentDirichletAllocationTransformer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/member[@name=""LightLDA""]/*' />",
@"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/example[@name=""LightLDA""]/*' />" })]
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LatentDirichletAllocationTransformer.Arguments input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));

var h = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
var view = new LdaTransform(h, input, input.Data);
var cols = input.Column.Select(colPair => new LatentDirichletAllocationTransformer.ColumnInfo(colPair, input)).ToArray();
var est = new LatentDirichletAllocationEstimator(h, cols);
var view = est.Fit(input.Data).Transform(input.Data);

return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
Expand Down
174 changes: 174 additions & 0 deletions src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.StaticPipe.Runtime;
using Microsoft.ML.Transforms.Text;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.StaticPipe
{
/// <summary>
/// Information on the result of fitting a LDA transform.
/// </summary>
public sealed class LdaFitResult
{
/// <summary>
/// For user defined delegates that accept instances of the containing type.
/// </summary>
/// <param name="result"></param>
public delegate void OnFit(LdaFitResult result);

public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
{
LdaTopicSummary = ldaTopicSummary;
}
}

public static class LdaStaticExtensions
{
private struct Config
{
public readonly int NumTopic;
public readonly Single AlphaSum;
Copy link
Member

@wschin wschin Nov 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public readonly Single AlphaSum;
public readonly float AlphaSum;
``` #Resolved

public readonly Single Beta;
Copy link
Member

@wschin wschin Nov 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public readonly Single Beta;
public readonly float Beta;
``` #Resolved

public readonly int MHStep;
public readonly int NumIter;
public readonly int LikelihoodInterval;
public readonly int NumThread;
public readonly int NumMaxDocToken;
public readonly int NumSummaryTermPerTopic;
public readonly int NumBurninIter;
public readonly bool ResetRandomGenerator;

public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

public Config(int numTopic, Single alphaSum, Single beta, int mhStep, int numIter, int likelihoodInterval,
int numThread, int numMaxDocToken, int numSummaryTermPerTopic, int numBurninIter, bool resetRandomGenerator,
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
{
NumTopic = numTopic;
AlphaSum = alphaSum;
Beta = beta;
MHStep = mhStep;
NumIter = numIter;
LikelihoodInterval = likelihoodInterval;
NumThread = numThread;
NumMaxDocToken = numMaxDocToken;
NumSummaryTermPerTopic = numSummaryTermPerTopic;
NumBurninIter = numBurninIter;
ResetRandomGenerator = resetRandomGenerator;

OnFit = onFit;
}
}

private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
{
if (onFit == null)
return null;

return ldaTopicSummary => onFit(new LdaFitResult(ldaTopicSummary));
}

private interface ILdaCol
{
PipelineColumn Input { get; }
Config Config { get; }
}

private sealed class ImplVector : Vector<float>, ILdaCol
{
public PipelineColumn Input { get; }
public Config Config { get; }
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
{
Input = input;
Config = config;
}
}

private sealed class Rec : EstimatorReconciler
{
public static readonly Rec Inst = new Rec();

public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
PipelineColumn[] toOutput,
IReadOnlyDictionary<PipelineColumn, string> inputNames,
IReadOnlyDictionary<PipelineColumn, string> outputNames,
IReadOnlyCollection<string> usedNames)
{
var infos = new LatentDirichletAllocationTransformer.ColumnInfo[toOutput.Length];
Action<LatentDirichletAllocationTransformer> onFit = null;
for (int i = 0; i < toOutput.Length; ++i)
{
var tcol = (ILdaCol)toOutput[i];

infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],
tcol.Config.NumTopic,
tcol.Config.AlphaSum,
tcol.Config.Beta,
tcol.Config.MHStep,
tcol.Config.NumIter,
tcol.Config.LikelihoodInterval,
tcol.Config.NumThread,
tcol.Config.NumMaxDocToken,
tcol.Config.NumSummaryTermPerTopic,
tcol.Config.NumBurninIter,
tcol.Config.ResetRandomGenerator);

if (tcol.Config.OnFit != null)
{
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call.
onFit += tt => tcol.Config.OnFit(tt.GetLdaDetails(ii));
}
}

var est = new LatentDirichletAllocationEstimator(env, infos);
if (onFit == null)
return est;

return est.WithOnFitDelegate(onFit);
}
}

/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
/// <param name="input">A vector of floats representing the document.</param>
/// <param name="numTopic">The number of topics.</param>
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
/// <param name="mhstep">Number of Metropolis Hasting step.</param>
/// <param name="numIterations">Number of iterations.</param>
/// <param name="likelihoodInterval">Compute log likelihood over local dataset on this iteration interval.</param>
/// <param name="numThreads">The number of training threads. Default value depends on number of logical processors.</param>
/// <param name="numMaxDocToken">The threshold of maximum count of tokens per doc.</param>
/// <param name="numSummaryTermPerTopic">The number of words to summarize the topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Oct 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sprinkle it with dots at the end of sentence. #Resolved

public static Vector<float> ToLdaTopicVector(this Vector<float> input,
int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
LdaFitResult.OnFit onFit = null)
{
Contracts.CheckValue(input, nameof(input));
return new ImplVector(input,
new Config(numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic,
numBurninIterations, resetRandomGenerator, Wrap(onFit)));
}
}
}
Loading