diff --git a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
index f8fd1b476e..7bff665ba6 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
@@ -18,6 +18,7 @@ public sealed class TensorFlowModel : IDisposable
{
internal Session Session { get; }
internal string ModelPath { get; }
+ internal bool TreatOutputAsBatched { get; }
private readonly IHostEnvironment _env;
@@ -27,10 +28,12 @@ public sealed class TensorFlowModel : IDisposable
/// An object.
/// TensorFlow session object.
/// Location of the model from where was loaded.
- internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation)
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
+ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation, bool treatOutputAsBatched = true)
{
Session = session;
ModelPath = modelLocation;
+ TreatOutputAsBatched = treatOutputAsBatched;
_env = env;
_disposed = false;
}
@@ -40,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
///
public DataViewSchema GetModelSchema()
{
- return TensorFlowUtils.GetModelSchema(_env, Session.graph);
+ return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched);
}
///
@@ -49,7 +52,7 @@ public DataViewSchema GetModelSchema()
///
public DataViewSchema GetInputSchema()
{
- return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
+ return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder");
}
///
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
index 2ad63321d4..372d4b1029 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
@@ -35,5 +35,31 @@ public static class TensorflowCatalog
///
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation)
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);
+
+ ///
+ /// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
+ /// using .
+ /// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
+ /// also holds references to unmanaged resources that need to be freed either with an explicit
+ /// call to Dispose() or implicitly by declaring the variable with the "using" syntax/>
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// <param name="catalog">The transform's catalog.</param>
+ /// <param name="modelLocation">Location of the TensorFlow model.</param>
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation, bool treatOutputAsBatched)
+ => TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation, treatOutputAsBatched);
}
}
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index f5e85f158f..1537ab10cf 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -45,6 +45,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable
private readonly string _savedModelPath;
private readonly bool _isTemporarySavedModel;
private readonly bool _addBatchDimensionInput;
+ private readonly bool _treatOutputAsBatched;
internal readonly Session Session;
internal readonly Runner Runner;
internal readonly DataViewType[] OutputTypes;
@@ -71,8 +72,9 @@ private static VersionInfo GetVersionInfo()
modelSignature: "TENSFLOW",
//verWrittenCur: 0x00010001, // Initial
//verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel.
- verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs.
- verReadableCur: 0x00010003,
+ //verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs.
+ verWrittenCur: 0x00010004, // Added Support for treating batch as output or not.
+ verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TensorFlowTransformer).Assembly.FullName);
@@ -82,16 +84,17 @@ private static VersionInfo GetVersionInfo()
/// Transform for scoring Tensorflow models. Input data column names/types must exactly match
/// all model input names. Only the output columns specified will be generated.
/// This convenience method avoids reloading of TensorFlow model.
- /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema.
+ /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema.
///
/// The environment to use.
- /// object created with .
+ /// object created with .
/// The output columns to generate. Names must match model specifications. Data types are inferred from model.
/// The name of the input data columns. Must match model's input names. If set to , the value of the will be used as source.
/// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
/// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.
- internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false)
- : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput)
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
+ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string outputColumnName, string inputColumnName = null, bool addBatchDimensionInput = false, bool treatOutputAsBatched = true)
+ : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched)
{
}
@@ -99,16 +102,17 @@ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo
/// Transform for scoring Tensorflow models. Input data column names/types must exactly match
/// all model input names. Only the output columns specified will be generated.
/// This convenience method avoids reloading of TensorFlow model.
- /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema.
+ /// It is useful in a situation where user has already loaded TensorFlow model using for inspecting model schema.
///
/// The environment to use.
- /// object created with .
+ /// object created with .
/// The name of the input data columns. Must match model's input names.
/// The output columns to generate. Names must match model specifications. Data types are inferred from model.
/// Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
/// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.
- internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false)
- : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput)
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
+ internal TensorFlowTransformer(IHostEnvironment env, TensorFlowModel tfModelInfo, string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false, bool treatOutputAsBatched = true)
+ : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched)
{
}
@@ -122,6 +126,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
// *** Binary format ***
// byte: indicator for frozen models
// byte: indicator for adding batch dimension in input
+ // byte: indicator for treating output as batched
// stream: tensorFlow model.
// int: number of input columns
// for each input column
@@ -129,13 +134,13 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
// int: number of output columns
// for each output column
// int: id of output column name
- GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput);
+ GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched);
if (isFrozen)
{
byte[] modelBytes = null;
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();
- return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput);
+ return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
}
var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
@@ -164,7 +169,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
}
});
- return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput);
+ return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
}
catch (Exception)
{
@@ -236,7 +241,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput)
+ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched)
{
isFrozen = true;
bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002;
@@ -248,6 +253,11 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
if (isAddingBatchDimensionSupported)
addBatchDimensionInput = ctx.Reader.ReadBoolByte();
+ treatOutputAsBatched = true;
+ bool isTreatingOutputAsBatchedSupported = ctx.Header.ModelVerReadable >= 0x00010004;
+ if (isTreatingOutputAsBatchedSupported)
+ treatOutputAsBatched = ctx.Reader.ReadBoolByte();
+
var numInputs = ctx.Reader.ReadInt32();
env.CheckDecode(numInputs > 0);
inputs = new string[numInputs];
@@ -267,7 +277,7 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel,
- bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null)
+ bool addBatchDimensionInput, int batchSize = 1, TensorFlowEstimator.Options options = null, IDataView input = null, bool treatOutputAsBatched = true)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransformer)))
{
@@ -279,11 +289,12 @@ internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] o
_isTemporarySavedModel = isTemporarySavedModel;
Session = session;
_addBatchDimensionInput = addBatchDimensionInput;
+ _treatOutputAsBatched = treatOutputAsBatched;
Inputs = inputColumnNames;
Outputs = outputColumnNames;
tf.compat.v1.disable_eager_execution();
- (TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs);
+ (TFOutputTypes, OutputTypes, TFOutputOperations) = GetOutputInfo(Host, Session, Outputs, treatOutputAsBatched);
(TFInputTypes, TFInputShapes, TFInputOperations) = GetInputInfo(Host, Session, Inputs, batchSize);
TFInputNodes = new TF_Output[Inputs.Length];
@@ -359,7 +370,7 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status
return new TensorShape(dims.Select(x => (int)x).ToArray());
}
- internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs)
+ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Operation, int)[]) GetOutputInfo(IHost host, Session session, string[] outputs, bool treatOutputAsBatched)
{
var tfOutputTypes = new TF_DataType[outputs.Length];
var outputTypes = new DataViewType[outputs.Length];
@@ -384,7 +395,12 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera
// If there are other dimension that are unknown the transformer will return a variable length vector.
// This is the work around in absence of reshape transformer.
var idims = shape.dims;
- int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];
+
+ int[] dims = idims;
+ if (treatOutputAsBatched)
+ {
+ dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];
+ }
for (int j = 0; j < dims.Length; j++)
dims[j] = dims[j] == -1 ? 0 : dims[j];
if (dims == null || dims.Length == 0)
@@ -415,6 +431,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
// *** Binary format ***
// byte: indicator for frozen models
// byte: indicator for adding batch dimension in input
+ // byte: indicator for treating output as batched
// stream: tensorFlow model.
// int: number of input columns
// for each input column
@@ -425,6 +442,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
var isFrozen = string.IsNullOrEmpty(_savedModelPath);
ctx.Writer.WriteBoolByte(isFrozen);
ctx.Writer.WriteBoolByte(_addBatchDimensionInput);
+ ctx.Writer.WriteBoolByte(_treatOutputAsBatched);
if (isFrozen)
{
using (var status = new Status())
@@ -876,6 +894,15 @@ internal sealed class Options : TransformInputBase
///
[Argument(ArgumentType.AtMostOnce, HelpText = "Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].", SortOrder = 16)]
public bool AddBatchDimensionInputs = false;
+
+ /// <summary>
+ /// If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.
+ /// </summary>
+ /// <remarks>
+ /// This parameter is used to deal with models that have unknown output shape and it needs to be interpreted in ML.NET as a vector of unknown length and not as a batch dimension.
+ /// </remarks>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.", SortOrder = 17)]
+ public bool TreatOutputAsBatched = true;
}
private readonly IHost _host;
@@ -897,7 +924,7 @@ internal TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, s
}
internal TensorFlowEstimator(IHostEnvironment env, Options options)
- : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation))
+ : this(env, options, TensorFlowUtils.LoadTensorFlowModel(env, options.ModelLocation, options.TreatOutputAsBatched))
{
}
@@ -906,20 +933,23 @@ internal TensorFlowEstimator(IHostEnvironment env, Options options, TensorFlowMo
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator));
_options = options;
_tensorFlowModel = tensorFlowModel;
+ if (!tensorFlowModel.TreatOutputAsBatched)
+ _options.TreatOutputAsBatched = tensorFlowModel.TreatOutputAsBatched;
tensorFlowModel.Session.graph.as_default();
- var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns);
+ var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, _options.InputColumns);
_tfInputTypes = inputTuple.tfInputTypes;
- var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns);
+ var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, _options.OutputColumns, _options.TreatOutputAsBatched);
_outputTypes = outputTuple.outputTypes;
}
- private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput)
+ private static Options CreateArguments(TensorFlowModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput, bool treatOutputAsBatched = true)
{
var options = new Options();
options.ModelLocation = tensorFlowModel.ModelPath;
options.InputColumns = inputColumnName;
options.OutputColumns = outputColumnNames;
options.AddBatchDimensionInputs = addBatchDimensionInput;
+ options.TreatOutputAsBatched = treatOutputAsBatched;
return options;
}
@@ -959,7 +989,7 @@ public TensorFlowTransformer Fit(IDataView input)
if (_transformer == null)
{
_transformer = new TensorFlowTransformer(_host, _tensorFlowModel.Session, _options.OutputColumns, _options.InputColumns,
- IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs);
+ IsSavedModel(_host, _options.ModelLocation) ? _options.ModelLocation : null, false, _options.AddBatchDimensionInputs, treatOutputAsBatched: _options.TreatOutputAsBatched);
}
// Validate input schema.
_transformer.GetOutputSchema(input.Schema);
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
index 805aedcb3d..8fbbd772a0 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
@@ -32,7 +32,7 @@ internal static class TensorFlowUtils
///
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
- internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null)
+ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, bool treatOutputAsBatched, string opType = null)
{
var schemaBuilder = new DataViewSchema.Builder();
foreach (Operation op in graph)
@@ -79,7 +79,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
// Construct the final ML.NET type of a Tensorflow variable.
var tensorShape = op.output.TensorShape.dims;
- if(tensorShape == null)
+ if (tensorShape == null)
{
// primitive column type
schemaBuilder.AddColumn(op.name, mlType, metadataBuilder.ToAnnotations());
@@ -90,7 +90,24 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
DataViewType columnType = new VectorDataViewType(mlType);
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
- columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
+ // treatOutputAsBatched == true means that if the first dimension is greater
+ // than 0 we take the tensor shape as is. If the first value is less then 0, we treat it as the batch input so we can
+ // ignore it for the shape of the ML.NET vector. I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as
+ // batch input, and so the ML.NET data type will be a vector of length 5.
+ if (treatOutputAsBatched)
+ {
+ columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
+ }
+ // When treatOutputAsBatched is false, if the first value is less than 0 we want to set it to 0. TensorFlow
+ // represents an unknown size as -1, but ML.NET represents it as 0 so we need to convert it.
+ // I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unknown length, and so the ML.NET
+ // data type will be a vector of 2 dimensions, where the first dimension is unknown and the second has a length of 5.
+ else
+ {
+ if (tensorShape[0] < 0)
+ tensorShape[0] = 0;
+ columnType = new VectorDataViewType(mlType, tensorShape);
+ }
schemaBuilder.AddColumn(op.name, columnType, metadataBuilder.ToAnnotations());
}
@@ -108,10 +125,11 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
///
/// The environment to use.
/// Model to load.
- internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
+ internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
{
- using var model = LoadTensorFlowModel(env, modelPath);
- return GetModelSchema(env, model.Session.graph);
+ using var model = LoadTensorFlowModel(env, modelPath, treatOutputAsBatched);
+ return GetModelSchema(env, model.Session.graph, treatOutputAsBatched);
}
///
@@ -119,11 +137,12 @@ internal static DataViewSchema GetModelSchema(IHostEnvironment env, string model
///
/// The environment to use.
/// The model to load.
+ /// <param name="treatOutputAsBatched">If the first dimension of the output is unknown, should it be treated as batched or not.</param>
///
- internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath)
+ internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
{
var session = GetSession(env, modelPath);
- return new TensorFlowModel(env, session, modelPath);
+ return new TensorFlowModel(env, session, modelPath, treatOutputAsBatched: treatOutputAsBatched);
}
internal static PrimitiveDataViewType Tf2MlNetType(TF_DataType type)
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 429ff6bb78..7253e4533f 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -23613,6 +23613,15 @@
"SortOrder": 16.0,
"IsNullable": false,
"Default": false
+ },
+ {
+ "Name": "TreatOutputAsBatched",
+ "Type": "Bool",
+ "Desc": "If the first dimension of the output is unknown, should it be treated as batched or not. e.g. output = [-1] will be read as a vector of unknown length when this is false.",
+ "Required": false,
+ "SortOrder": 17.0,
+ "IsNullable": false,
+ "Default": true
}
],
"Outputs": [
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 175e958f1b..3508cfe481 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -1152,7 +1152,6 @@ public void TensorFlowGettingSchemaMultipleTimes()
}
}
-
[TensorFlowFact]
public void TensorFlowTransformCifarInvalidShape()
{
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index 8aba9b08a5..ff6dbd456f 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -184,6 +184,52 @@ public void TestTensorFlow()
}
}
+ [TensorFlowFact]
+ public void TreatOutputAsBatched()
+ {
+ var modelLocation = "cifar_model/frozen_model.pb";
+
+ var mlContext = new MLContext(seed: 1);
+ var imageHeight = 32;
+ var imageWidth = 32;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+
+ var data = ML.Data.LoadFromTextFile(dataFile, new[] {
+ new TextLoader.Column("imagePath", DataKind.String, 0),
+ new TextLoader.Column("name", DataKind.String, 1)
+ });
+
+ // Note that CamelCase column names are there to match the TF graph node names.
+ // Check and make sure save/load work correctly for the new TreatOutputAsBatched value.
+ var pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath")
+ .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth))
+ .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true))
+ .Append(ML.Model.LoadTensorFlowModel(modelLocation, false).ScoreTensorFlowModel("Output", "Input"));
+
+ TestEstimatorCore(pipe, data);
+ var schema = pipe.Fit(data).Transform(data).Schema;
+
+ // The dimensions of the output with treatOutputAsBatched set to false should be * 10
+ // as the first dimension of -1 is treated as an unknown dimension.
+ Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 0, 10), schema["Output"].Type);
+
+ // Note that CamelCase column names are there to match the TF graph node names.
+ // Test with TreatOutputAsBatched set to default value of true.
+ pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath")
+ .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth))
+ .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true))
+ .Append(ML.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input"));
+
+ TestEstimatorCore(pipe, data);
+ schema = pipe.Fit(data).Transform(data).Schema;
+
+ // The dimensions of the output with treatOutputAsBatched set to true should be 10
+ // as the first dimension of -1 is treated as the batch dimension.
+ Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 10), schema["Output"].Type);
+
+ }
+
[TensorFlowFact]
public void TestTensorFlowWithSchema()
{