Skip to content

Commit

Permalink
ProduceWordBags Onnx Export Fix (#5435)
Browse files: browse the repository at this point in the history.
* fix for issue

* fix documentation

* aligning test

* adding back line

* aligning fix

Co-authored-by: Keren Fuentes <[email protected]>
  • Loading branch information
Lynx1820 and Keren Fuentes authored Oct 15, 2020
1 parent afba0bd commit 82d4bb7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/TextCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCa
=> new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, stopwords);

/// <summary>
/// Create a <see cref="WordHashBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
/// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
/// <remarks>
Expand Down Expand Up @@ -363,7 +363,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);

/// <summary>
/// Create a <see cref="WordHashBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
/// <remarks>
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,10 @@ public void SaveAsOnnx(OnnxContext ctx)
string[] separators = column.SeparatorsArray.Select(c => c.ToString()).ToArray();
tokenizerNode.AddAttribute("separators", separators);

opType = "Squeeze";
var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name);
var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), "");
squeezeNode.AddAttribute("axes", new long[] { 1 });
opType = "Reshape";
var shape = ctx.AddInitializer(new long[] { 1, -1 }, new long[] { 2 }, "Shape");
var reshapeOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, 1), column.Name);
var reshapeNode = ctx.CreateNode(opType, new[] { intermediateVar, shape }, new[] { reshapeOutput }, ctx.GetNodeName(opType), "");
}
}
}
Expand Down
11 changes: 7 additions & 4 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1323,9 +1323,12 @@ public void NgramOnnxConversionTest(
weighting: weighting)),

mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
ngramLength: ngramLength,
useAllLengths: useAllLength,
weighting: weighting)
ngramLength: ngramLength,
useAllLengths: useAllLength,
weighting: weighting),

mlContext.Transforms.Text.TokenizeIntoWords("Tokens0", "Text")
.Append(mlContext.Transforms.Text.ProduceWordBags("Tokens", "Tokens0"))
};

for (int i = 0; i < pipelines.Length; i++)
Expand All @@ -1346,7 +1349,7 @@ public void NgramOnnxConversionTest(
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxFilePath, gpuDeviceId: _gpuDeviceId, fallbackToCpu: _fallbackToCpu);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
var columnName = i == pipelines.Length - 1 ? "Tokens" : "NGrams";
var columnName = i >= pipelines.Length - 2 ? "Tokens" : "NGrams";
CompareResults(columnName, columnName, transformedData, onnxResult, 3);

VBuffer<ReadOnlyMemory<char>> mlNetSlots = default;
Expand Down

0 comments on commit 82d4bb7

Please sign in to comment.