Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue 5020, allow ML.NET to load tf model with primitive input and output column #5468

Merged
merged 10 commits into from
Nov 5, 2020
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
<MicrosoftExtensionsTestPackageVersion>3.0.1</MicrosoftExtensionsTestPackageVersion>
<MicrosoftMLTestDatabasesPackageVersion>0.0.6-test</MicrosoftMLTestDatabasesPackageVersion>
<MicrosoftMLTestModelsPackageVersion>0.0.6-test</MicrosoftMLTestModelsPackageVersion>
<MicrosoftMLTensorFlowTestModelsVersion>0.0.12-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLOnnxTestModelsVersion>0.0.6-test</MicrosoftMLOnnxTestModelsVersion>
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Text;
using Microsoft.ML.Internal.Utilities;
using Tensorflow;
using Utils = Microsoft.ML.Internal.Utilities.Utils;
Expand All @@ -14,6 +15,16 @@ internal static class TensorTypeExtensions
{
public static void ToScalar<T>(this Tensor tensor, ref T dst) where T : unmanaged
{
// ML.NET stores string data as ReadOnlyMemory<char>, but ReadOnlyMemory<char> is not a
// valid data type for TensorFlow.NET, and an exception will be thrown if as_dtype is called on it,
// so the string type is handled specially here:
// read the string data first, then convert it to ReadOnlyMemory<char> and assign it to dst.
if (typeof(T) == typeof(ReadOnlyMemory<char>))
{
dst = (T)(object)tensor.StringData()[0].AsMemory();
return;
}

if (typeof(T).as_dtype() != tensor.dtype)
throw new NotSupportedException();

Expand Down
26 changes: 11 additions & 15 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,11 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera
// If there are other dimension that are unknown the transformer will return a variable length vector.
// This is the work around in absence of reshape transformer.
var idims = shape.dims;
int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new[] { 0 };
int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];
for (int j = 0; j < dims.Length; j++)
dims[j] = dims[j] == -1 ? 0 : dims[j];
if (dims == null || dims.Length == 0)
{
dims = new[] { 1 };
outputTypes[i] = Tf2MlNetType(tfOutputType);
}
else
Expand Down Expand Up @@ -503,20 +502,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
throw Host.Except("Variable length input columns not supported");

_isInputVector[i] = type is VectorDataViewType;
if (!_isInputVector[i])
throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1");
vecType = (VectorDataViewType)type;
var expectedType = Tf2MlNetType(_parent.TFInputTypes[i]);
if (type.GetItemType() != expectedType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString());
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.dims;

var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
if (shape == null || (shape.Length == 0))
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
_fullySpecifiedShapes[i] = new TensorShape();
else
{
vecType = (VectorDataViewType)type;
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
int valCount = 1;
Expand Down Expand Up @@ -561,7 +558,10 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :

if (_parent._addBatchDimensionInput)
{
var l = new int[_fullySpecifiedShapes[i].ndim + 1];
// The ndim of a default TensorShape is -1; treat it as 0 in that case.
// After the batch dimension is added, the input column changes from type to type[].
var originDim = _fullySpecifiedShapes[i].ndim < 0 ? 0 : _fullySpecifiedShapes[i].ndim;
var l = new int[originDim + 1];
l[0] = 1;
for (int ishape = 1; ishape < l.Length; ishape++)
l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1];
Expand Down Expand Up @@ -729,11 +729,10 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
{
_srcgetter = input.GetGetter<T>(input.Schema[colIndex]);
_tfShape = tfShape;
long size = 0;
long size = 1;
_position = 0;
if (tfShape.dims.Length != 0)
if (tfShape.dims != null && tfShape.dims.Length != 0)
{
size = 1;
foreach (var dim in tfShape.dims)
size *= dim;
}
Expand All @@ -744,8 +743,7 @@ public Tensor GetTensor()
{
var scalar = default(T);
_srcgetter(ref scalar);
var tensor = new Tensor(new[] { scalar });
tensor.set_shape(_tfShape);
var tensor = TensorFlowUtils.CastDataAndReturnAsTensor(scalar);
return tensor;
}

Expand Down Expand Up @@ -928,8 +926,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var input = _options.InputColumns[i];
if (!inputSchema.TryFindColumn(input, out var col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
var expectedType = Tf2MlNetType(_tfInputTypes[i]);
if (col.ItemType != expectedType)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
Expand Down
30 changes: 30 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,36 @@ internal static Tensor CastDataAndReturnAsTensor<T>(T[] data, TensorShape tfShap
return new Tensor(new NDArray(data, tfShape));
}

/// <summary>
/// Wraps a single value in a <see cref="Tensor"/>, selecting the TensorFlow data type from
/// the static type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">
/// Static type of the value. Handled types: the signed and unsigned 8/16/32/64-bit integers,
/// <see cref="bool"/>, <see cref="float"/>, <see cref="double"/> and ReadOnlyMemory&lt;char&gt;.
/// </typeparam>
/// <param name="data">The value to place in the tensor.</param>
/// <returns>A new <see cref="Tensor"/> constructed from <paramref name="data"/>.</returns>
/// <exception cref="ArgumentException"><typeparamref name="T"/> is not one of the handled types.</exception>
internal static Tensor CastDataAndReturnAsTensor<T>(T data)
{
    // Dispatch is on typeof(T), not on the runtime type of data. The cast through object is
    // what lets the compiler accept a conversion from the open generic T to a concrete type.

    // Signed integers.
    if (typeof(T) == typeof(sbyte))
        return new Tensor((sbyte)(object)data, TF_DataType.TF_INT8);
    if (typeof(T) == typeof(short))
        return new Tensor((short)(object)data, TF_DataType.TF_INT16);
    if (typeof(T) == typeof(int))
        return new Tensor((int)(object)data, TF_DataType.TF_INT32);
    if (typeof(T) == typeof(long))
        return new Tensor((long)(object)data, TF_DataType.TF_INT64);

    // Unsigned integers.
    if (typeof(T) == typeof(byte))
        return new Tensor((byte)(object)data, TF_DataType.TF_UINT8);
    if (typeof(T) == typeof(ushort))
        return new Tensor((ushort)(object)data, TF_DataType.TF_UINT16);
    if (typeof(T) == typeof(uint))
        return new Tensor((uint)(object)data, TF_DataType.TF_UINT32);
    if (typeof(T) == typeof(ulong))
        return new Tensor((ulong)(object)data, TF_DataType.TF_UINT64);

    // Boolean and floating point.
    if (typeof(T) == typeof(bool))
        return new Tensor((bool)(object)data, TF_DataType.TF_BOOL);
    if (typeof(T) == typeof(float))
        return new Tensor((float)(object)data, TF_DataType.TF_FLOAT);
    if (typeof(T) == typeof(double))
        return new Tensor((double)(object)data, TF_DataType.TF_DOUBLE);

    // Text: the value is converted with ToString() and handed to the string constructor.
    if (typeof(T) == typeof(ReadOnlyMemory<char>))
        return new Tensor(data.ToString());

    throw new ArgumentException($"Unsupported data type of {typeof(T)} to convert to Tensor.");
}

/// <summary>
/// Use the runner class to easily configure inputs, outputs and targets to be passed to the session runner.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,20 @@ class TextOutput
public string[] BOut { get; set; }
}

// Input row type for TensorFlowPrimitiveInputTest: two scalar (non-vector) string fields whose
// names match the "input1" and "input2" input columns that the test passes to ScoreTensorFlowModel.
class PrimitiveInput
{
// NOTE(review): LoadColumn(0, 1) here and LoadColumn(1, 2) below give two-argument, overlapping
// values on scalar string fields. The test loads with MultiFileSource(null), so presumably no
// file is ever parsed through these mappings -- confirm whether single indices were intended.
[LoadColumn(0, 1)]
public string input1;

[LoadColumn(1, 2)]
public string input2;
}

// Output row type for TensorFlowPrimitiveInputTest: a single scalar string property whose name
// matches the "string_merge" output column the test requests from the model.
class PrimitiveOutput
{
public string string_merge { get; set; }
}

[TensorFlowFact]
public void TensorFlowStringTest()
{
Expand All @@ -1286,6 +1300,32 @@ public void TensorFlowStringTest()
Assert.Equal(string.Join(" ", input.B).Replace("/", " "), textOutput.BOut[0]);
}

/// <summary>
/// Scores a TensorFlow model whose inputs and output are scalar (non-vector) string columns and
/// checks that the "string_merge" output equals the two input strings concatenated.
/// </summary>
[TensorFlowFact]
public void TensorFlowPrimitiveInputTest()
{
    using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test");

    // The model schema must expose both inputs; only presence matters, so the indices are discarded.
    var schema = tensorFlowModel.GetModelSchema();
    Assert.True(schema.TryGetColumnIndex("input1", out _));
    Assert.True(schema.TryGetColumnIndex("input2", out _));

    // MultiFileSource(null) supplies no files; the loader is presumably needed only so that
    // Fit has a data view with the PrimitiveInput columns.
    var dataview = _mlContext.Data.CreateTextLoader<PrimitiveInput>().Load(new MultiFileSource(null));

    var pipeline = tensorFlowModel.ScoreTensorFlowModel(
        inputColumnNames: new[] { "input1", "input2" },
        outputColumnNames: new[] { "string_merge" });

    // The prediction engine is disposable; release it when the test ends instead of leaking it.
    using var predictionEngine = _mlContext.Model.CreatePredictionEngine<PrimitiveInput, PrimitiveOutput>(pipeline.Fit(dataview));

    var input = new PrimitiveInput
    {
        input1 = "This is fine.",
        input2 = "Thank you very much!."
    };

    var primitiveOutput = predictionEngine.Predict(input);

    Assert.Equal("This is fine.Thank you very much!.", primitiveOutput.string_merge);
}

[TensorFlowFact]
public void TensorFlowImageClassificationDefault()
{
Expand Down