Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue 5020, allow ML.NET to load tf model with primitive input and output column #5468

Merged
merged 10 commits into from
Nov 5, 2020
22 changes: 10 additions & 12 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -503,20 +503,20 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
throw Host.Except("Variable length input columns not supported");

_isInputVector[i] = type is VectorDataViewType;
if (!_isInputVector[i])
throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1");
vecType = (VectorDataViewType)type;
//if (!_isInputVector[i])
// throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1");
Copy link
Contributor

@harishsk harishsk Nov 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not applicable anymore, please delete the lines. #Resolved

var expectedType = Tf2MlNetType(_parent.TFInputTypes[i]);
if (type.GetItemType() != expectedType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString());
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.dims;

var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
if (shape == null || (shape.Length == 0))
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
_fullySpecifiedShapes[i] = new TensorShape();
else
{
vecType = (VectorDataViewType)type;
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
int valCount = 1;
Expand Down Expand Up @@ -729,11 +729,10 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
{
_srcgetter = input.GetGetter<T>(input.Schema[colIndex]);
_tfShape = tfShape;
long size = 0;
long size = 1;
_position = 0;
if (tfShape.dims.Length != 0)
if (tfShape.dims != null && tfShape.dims.Length != 0)
{
size = 1;
foreach (var dim in tfShape.dims)
size *= dim;
}
Expand All @@ -744,8 +743,7 @@ public Tensor GetTensor()
{
var scalar = default(T);
_srcgetter(ref scalar);
var tensor = new Tensor(new[] { scalar });
tensor.set_shape(_tfShape);
var tensor = TensorFlowUtils.CastDataAndReturnAsTensor(scalar);
return tensor;
}

Expand Down Expand Up @@ -928,8 +926,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
var input = _options.InputColumns[i];
if (!inputSchema.TryFindColumn(input, out var col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
//if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
Copy link
Contributor

@harishsk harishsk Nov 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete if not applicable. #Resolved

// throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
var expectedType = Tf2MlNetType(_tfInputTypes[i]);
if (col.ItemType != expectedType)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
Expand Down
30 changes: 30 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,36 @@ internal static Tensor CastDataAndReturnAsTensor<T>(T[] data, TensorShape tfShap
return new Tensor(new NDArray(data, tfShape));
}

internal static Tensor CastDataAndReturnAsTensor<T>(T data)
{
if (typeof(T) == typeof(sbyte))
return new Tensor((sbyte)(object)data, TF_DataType.TF_INT8);
else if (typeof(T) == typeof(long))
return new Tensor((long)(object)data, TF_DataType.TF_INT64);
else if (typeof(T) == typeof(Int32))
return new Tensor((Int32)(object)data, TF_DataType.TF_INT32);
else if (typeof(T) == typeof(Int16))
return new Tensor((Int16)(object)data, TF_DataType.TF_INT16);
else if (typeof(T) == typeof(byte))
return new Tensor((byte)(object)data, TF_DataType.TF_UINT8);
else if (typeof(T) == typeof(ulong))
return new Tensor((ulong)(object)data, TF_DataType.TF_UINT64);
else if (typeof(T) == typeof(UInt32))
return new Tensor((UInt32)(object)data, TF_DataType.TF_UINT32);
else if (typeof(T) == typeof(UInt16))
return new Tensor((UInt16)(object)data, TF_DataType.TF_UINT16);
else if (typeof(T) == typeof(bool))
return new Tensor((bool)(object)data, TF_DataType.TF_BOOL);
else if (typeof(T) == typeof(float))
return new Tensor((float)(object)data, TF_DataType.TF_FLOAT);
else if (typeof(T) == typeof(double))
return new Tensor((double)(object)data, TF_DataType.TF_DOUBLE);
else if (typeof(T) == typeof(ReadOnlyMemory<char>))
return new Tensor(data.ToString());

throw new ArgumentException($"Unsupported data type of {typeof(T)} to convert to Tensor.");
}

/// <summary>
/// Use the runner class to easily configure inputs, outputs and targets to be passed to the session runner.
/// </summary>
Expand Down