-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix issue 5020, allow ML.NET to load tf model with primitive input and output column #5468
Merged
Merged
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
bdf36b7
handle exception during GetNextPipeline for AutoML
frank-dong-ms 281305d
take comments
frank-dong-ms 4d3bc6a
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
frank-dong-ms 941b9f2
Enable TesnflowTransformer take primitive type as input column
frank-dong-ms a0f7eee
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
frank-dong-ms 8806213
undo unnecessary changes
frank-dong-ms af49c75
add test
frank-dong-ms bcf68f0
update on test
frank-dong-ms 5fcbc96
remove unnecessary line
frank-dong-ms ca8dac3
take comments
frank-dong-ms File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -503,20 +503,20 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : | |
throw Host.Except("Variable length input columns not supported"); | ||
|
||
_isInputVector[i] = type is VectorDataViewType; | ||
if (!_isInputVector[i]) | ||
throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1"); | ||
vecType = (VectorDataViewType)type; | ||
//if (!_isInputVector[i]) | ||
// throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1"); | ||
var expectedType = Tf2MlNetType(_parent.TFInputTypes[i]); | ||
if (type.GetItemType() != expectedType) | ||
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString()); | ||
var originalShape = _parent.TFInputShapes[i]; | ||
var shape = originalShape.dims; | ||
|
||
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); | ||
if (shape == null || (shape.Length == 0)) | ||
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims); | ||
_fullySpecifiedShapes[i] = new TensorShape(); | ||
else | ||
{ | ||
vecType = (VectorDataViewType)type; | ||
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray(); | ||
// If the column is one dimension we make sure that the total size of the TF shape matches. | ||
// Compute the total size of the known dimensions of the shape. | ||
int valCount = 1; | ||
|
@@ -729,11 +729,10 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape) | |
{ | ||
_srcgetter = input.GetGetter<T>(input.Schema[colIndex]); | ||
_tfShape = tfShape; | ||
long size = 0; | ||
long size = 1; | ||
_position = 0; | ||
if (tfShape.dims.Length != 0) | ||
if (tfShape.dims != null && tfShape.dims.Length != 0) | ||
{ | ||
size = 1; | ||
foreach (var dim in tfShape.dims) | ||
size *= dim; | ||
} | ||
|
@@ -744,8 +743,7 @@ public Tensor GetTensor() | |
{ | ||
var scalar = default(T); | ||
_srcgetter(ref scalar); | ||
var tensor = new Tensor(new[] { scalar }); | ||
tensor.set_shape(_tfShape); | ||
var tensor = TensorFlowUtils.CastDataAndReturnAsTensor(scalar); | ||
return tensor; | ||
} | ||
|
||
|
@@ -928,8 +926,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
var input = _options.InputColumns[i]; | ||
if (!inputSchema.TryFindColumn(input, out var col)) | ||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); | ||
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector)) | ||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString()); | ||
//if (!(col.Kind == SchemaShape.Column.VectorKind.Vector)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please delete if not applicable. #Resolved |
||
// throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString()); | ||
var expectedType = Tf2MlNetType(_tfInputTypes[i]); | ||
if (col.ItemType != expectedType) | ||
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is not applicable anymore, please delete the lines. #Resolved