-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert LdaTransform to IEstimator/ITransformer API #1410
Merged
Merged
Changes from all commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
acd8964
preparing to convert LDATransform to the IEstimator/ITransformer para…
abgoswam cd2f20c
Added ColumnInfo and TransformInfo
abgoswam c289ed1
existing tests pass after refactor
abgoswam 8dfb527
fix re-arrangement of arguments in entrypoints
abgoswam bcb3b0d
refactor + cleanup of TopicSummary
abgoswam 1d39408
enabled OnFit to return LdaState
abgoswam 9ff60f7
fixed merge conflicts
abgoswam ad36d2f
fix build issues after merge; fix review comments
abgoswam b0e0375
taking care of review comments - 1
abgoswam b0422e4
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam 7bc6e2b
merge with master; re-enable LDA tests before taking care of addition…
abgoswam e42c5e4
review comments - 2. rename LdaTransform to LdaTransformer
abgoswam c099d4a
review comments - 3. throw ExceptSchemaMismatch; default values; ToIm…
abgoswam a1d14ed
review comments - 4. output column; expression body definition
abgoswam 3f39a04
review comments - 4; added a basic test that exercises TestEstimatorC…
abgoswam 57cd1c5
review comments - 5; rename _exes to _columns; preparing changes for …
abgoswam d4a4283
review comments - 6; make training a private static method.
abgoswam c7fb50a
review comments - 7; make Create() method private
abgoswam d7660ca
review comments - 8; added internal constructor for ColumnInfo()
abgoswam e0d501b
review comments - 9; fixed types for Single, Float
abgoswam c91afbb
review comments - 10; made LdaState internal, expose LDA summary info…
abgoswam 4238fa1
review comments - 11; added a command line unit test
abgoswam 34bb2e9
review comments - 12; no-op when there is no data (maml command line …
abgoswam 8b70ab1
review comments - 13; added mlcontext extension for LDA
abgoswam b3c1284
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam b6e4028
review comments - 14; code refactor; avoid using abbreviation(Lda); r…
abgoswam 5397de5
review comments - 15; schema changes
abgoswam edd60af
review comments - 15; provide better user-facing description. renamed…
abgoswam b073038
review comments - 16; fix build break
abgoswam 49da3ee
review comments - 16; include words in LdaSummary (this also resolves…
abgoswam 0724290
fixing merge conflicts
abgoswam 5073baa
fixing merge conflicts
abgoswam d1481f8
fix build break because of manifest changes
abgoswam 65125d4
Merge branch 'master' into abgoswam/LDA_conversion
abgoswam b869d7f
updated to latest interface changes
abgoswam 62955a8
review comments - 17; named tuple, namespace change, fix CheckParams
abgoswam 40333a7
review comments - 18; ImmutableArray
abgoswam 850856b
review comments - 19; update summary text; remove dup code in constru…
abgoswam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
174 changes: 174 additions & 0 deletions
174
src/Microsoft.ML.Transforms/Text/LdaStaticExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,174 @@ | ||||||||
// Licensed to the .NET Foundation under one or more agreements. | ||||||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||||||
// See the LICENSE file in the project root for more information. | ||||||||
|
||||||||
using Microsoft.ML.Core.Data; | ||||||||
using Microsoft.ML.Runtime; | ||||||||
using Microsoft.ML.Runtime.Data; | ||||||||
using Microsoft.ML.StaticPipe.Runtime; | ||||||||
using Microsoft.ML.Transforms.Text; | ||||||||
using System; | ||||||||
using System.Collections.Generic; | ||||||||
|
||||||||
namespace Microsoft.ML.StaticPipe | ||||||||
{ | ||||||||
/// <summary>
/// Information on the result of fitting an LDA transform, packaged so it can be handed to a user callback.
/// </summary>
public sealed class LdaFitResult
{
    /// <summary>
    /// Signature for user-defined delegates that accept an instance of <see cref="LdaFitResult"/>.
    /// </summary>
    /// <param name="result">The fit result being delivered to the delegate.</param>
    public delegate void OnFit(LdaFitResult result);

    /// <summary>
    /// The LDA summary that was supplied when this result was constructed.
    /// </summary>
    public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;

    /// <summary>
    /// Creates a fit result that exposes the given summary through <see cref="LdaTopicSummary"/>.
    /// </summary>
    /// <param name="ldaTopicSummary">The summary to store.</param>
    public LdaFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
        => LdaTopicSummary = ldaTopicSummary;
}
|
||||||||
public static class LdaStaticExtensions | ||||||||
{ | ||||||||
/// <summary>
/// Immutable bundle of the settings chosen for one LDA output column, together with the optional
/// callback to run once fitting has produced a summary. Values are stored exactly as passed in;
/// no validation is performed here.
/// </summary>
private struct Config
{
    /// <summary>The number of topics.</summary>
    public readonly int NumTopic;

    /// <summary>Dirichlet prior on document-topic vectors.</summary>
    public readonly float AlphaSum;

    /// <summary>Dirichlet prior on vocab-topic vectors.</summary>
    public readonly float Beta;

    /// <summary>Number of Metropolis-Hastings steps.</summary>
    public readonly int MHStep;

    /// <summary>Number of iterations.</summary>
    public readonly int NumIter;

    /// <summary>Iteration interval at which the log likelihood over the local dataset is computed.</summary>
    public readonly int LikelihoodInterval;

    /// <summary>The number of training threads.</summary>
    public readonly int NumThread;

    /// <summary>The threshold of maximum count of tokens per document.</summary>
    public readonly int NumMaxDocToken;

    /// <summary>The number of words used to summarize a topic.</summary>
    public readonly int NumSummaryTermPerTopic;

    /// <summary>The number of burn-in iterations.</summary>
    public readonly int NumBurninIter;

    /// <summary>Whether to reset the random number generator for each document.</summary>
    public readonly bool ResetRandomGenerator;

    /// <summary>Callback receiving the LDA summary for this column; null when the user supplied none.</summary>
    public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;

    /// <summary>
    /// Stores every argument in the field of the corresponding name.
    /// </summary>
    public Config(
        int numTopic,
        float alphaSum,
        float beta,
        int mhStep,
        int numIter,
        int likelihoodInterval,
        int numThread,
        int numMaxDocToken,
        int numSummaryTermPerTopic,
        int numBurninIter,
        bool resetRandomGenerator,
        Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
    {
        // Model shape and priors.
        NumTopic = numTopic;
        AlphaSum = alphaSum;
        Beta = beta;

        // Sampling and iteration schedule.
        MHStep = mhStep;
        NumIter = numIter;
        NumBurninIter = numBurninIter;
        LikelihoodInterval = likelihoodInterval;
        ResetRandomGenerator = resetRandomGenerator;

        // Resource and size limits.
        NumThread = numThread;
        NumMaxDocToken = numMaxDocToken;
        NumSummaryTermPerTopic = numSummaryTermPerTopic;

        // Post-fit notification.
        OnFit = onFit;
    }
}
|
||||||||
/// <summary>
/// Adapts a user-facing <see cref="LdaFitResult.OnFit"/> delegate into an action over the raw LDA summary.
/// </summary>
/// <param name="onFit">The user delegate; may be null.</param>
/// <returns>
/// Null when <paramref name="onFit"/> is null; otherwise an action that wraps each summary it receives
/// in a new <see cref="LdaFitResult"/> and forwards it to <paramref name="onFit"/>.
/// </returns>
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LdaFitResult.OnFit onFit)
{
    if (onFit != null)
        return summary => onFit(new LdaFitResult(summary));

    // No user delegate: propagate null so callers can skip registering a callback.
    return null;
}
|
||||||||
/// <summary>
/// What the reconciler reads from an LDA output column: the settings attached to it and the
/// pipeline column it is computed from.
/// </summary>
private interface ILdaCol
{
    /// <summary>The settings attached to this output column.</summary>
    Config Config { get; }

    /// <summary>The pipeline column this output column takes as input.</summary>
    PipelineColumn Input { get; }
}
|
||||||||
/// <summary>
/// The float vector column returned by the LDA extension method. It registers itself with the shared
/// <see cref="Rec"/> reconciler and remembers its input column and settings for later reconciliation.
/// </summary>
private sealed class ImplVector : Vector<float>, ILdaCol
{
    /// <summary>The settings this column was created with.</summary>
    public Config Config { get; }

    /// <summary>The column this one is computed from.</summary>
    public PipelineColumn Input { get; }

    /// <summary>
    /// Creates the column, declaring <paramref name="input"/> as its dependency to the base class.
    /// </summary>
    /// <param name="input">The source column.</param>
    /// <param name="config">The settings to carry along.</param>
    public ImplVector(PipelineColumn input, Config config)
        : base(Rec.Inst, input)
    {
        Config = config;
        Input = input;
    }
}
|
||||||||
/// <summary>
/// Reconciler that converts a set of LDA output columns into a single
/// <see cref="LatentDirichletAllocationEstimator"/>.
/// </summary>
private sealed class Rec : EstimatorReconciler
{
    /// <summary>The shared instance passed to every <see cref="ImplVector"/>.</summary>
    public static readonly Rec Inst = new Rec();

    /// <summary>
    /// Builds one column description per requested output, creates the estimator over all of them,
    /// and attaches the per-column fit callbacks when any were supplied.
    /// </summary>
    /// <param name="env">The host environment handed to the estimator.</param>
    /// <param name="toOutput">The output columns to produce; each is cast to <see cref="ILdaCol"/>.</param>
    /// <param name="inputNames">Maps each input pipeline column to its column name.</param>
    /// <param name="outputNames">Maps each output pipeline column to its column name.</param>
    /// <param name="usedNames">Not used by this reconciler.</param>
    /// <returns>The estimator, wrapped with an on-fit delegate if at least one column has a callback.</returns>
    public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
        PipelineColumn[] toOutput,
        IReadOnlyDictionary<PipelineColumn, string> inputNames,
        IReadOnlyDictionary<PipelineColumn, string> outputNames,
        IReadOnlyCollection<string> usedNames)
    {
        int columnCount = toOutput.Length;
        var columnInfos = new LatentDirichletAllocationTransformer.ColumnInfo[columnCount];
        Action<LatentDirichletAllocationTransformer> fitCallbacks = null;

        for (int index = 0; index < columnCount; index++)
        {
            PipelineColumn outputColumn = toOutput[index];
            var ldaColumn = (ILdaCol)outputColumn;
            Config settings = ldaColumn.Config;

            columnInfos[index] = new LatentDirichletAllocationTransformer.ColumnInfo(
                inputNames[ldaColumn.Input],
                outputNames[outputColumn],
                settings.NumTopic,
                settings.AlphaSum,
                settings.Beta,
                settings.MHStep,
                settings.NumIter,
                settings.LikelihoodInterval,
                settings.NumThread,
                settings.NumMaxDocToken,
                settings.NumSummaryTermPerTopic,
                settings.NumBurninIter,
                settings.ResetRandomGenerator);

            var summaryCallback = settings.OnFit;
            if (summaryCallback == null)
                continue;

            // Copy the loop counter into a per-iteration local: capturing the counter itself would
            // make every callback see its final value (toOutput.Length) when eventually invoked.
            int columnIndex = index;
            fitCallbacks += transformer => summaryCallback(transformer.GetLdaDetails(columnIndex));
        }

        var estimator = new LatentDirichletAllocationEstimator(env, columnInfos);
        if (fitCallbacks != null)
            return estimator.WithOnFitDelegate(fitCallbacks);

        return estimator;
    }
}
|
||||||||
/// <include file='doc.xml' path='doc/members/member[@name="LightLDA"]/*' />
/// <param name="input">A vector of floats representing the document.</param>
/// <param name="numTopic">The number of topics.</param>
/// <param name="alphaSum">Dirichlet prior on document-topic vectors.</param>
/// <param name="beta">Dirichlet prior on vocab-topic vectors.</param>
/// <param name="mhstep">Number of Metropolis-Hastings steps.</param>
/// <param name="numIterations">Number of iterations.</param>
/// <param name="likelihoodInterval">Compute log likelihood over the local dataset at this iteration interval.</param>
/// <param name="numThreads">The number of training threads. The default value depends on the number of logical processors.</param>
/// <param name="numMaxDocToken">The threshold of the maximum count of tokens per document.</param>
/// <param name="numSummaryTermPerTopic">The number of words used to summarize each topic.</param>
/// <param name="numBurninIterations">The number of burn-in iterations.</param>
/// <param name="resetRandomGenerator">Reset the random number generator for each document.</param>
/// <param name="onFit">Optional delegate called upon fitting with an <see cref="LdaFitResult"/> holding the LDA summary for this column.</param>
/// <returns>The float vector column produced by the LDA transform for <paramref name="input"/>.</returns>
public static Vector<float> ToLdaTopicVector(this Vector<float> input,
    int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic,
    Single alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum,
    Single beta = LatentDirichletAllocationEstimator.Defaults.Beta,
    int mhstep = LatentDirichletAllocationEstimator.Defaults.Mhstep,
    int numIterations = LatentDirichletAllocationEstimator.Defaults.NumIterations,
    int likelihoodInterval = LatentDirichletAllocationEstimator.Defaults.LikelihoodInterval,
    int numThreads = LatentDirichletAllocationEstimator.Defaults.NumThreads,
    int numMaxDocToken = LatentDirichletAllocationEstimator.Defaults.NumMaxDocToken,
    int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic,
    int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations,
    bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator,
    LdaFitResult.OnFit onFit = null)
{
    Contracts.CheckValue(input, nameof(input));

    // Bundle the settings first; the user delegate is adapted to the summary-based callback form.
    var settings = new Config(
        numTopic,
        alphaSum,
        beta,
        mhstep,
        numIterations,
        likelihoodInterval,
        numThreads,
        numMaxDocToken,
        numSummaryTermPerTopic,
        numBurninIterations,
        resetRandomGenerator,
        Wrap(onFit));

    return new ImplVector(input, settings);
}
} | ||||||||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.