diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
index 1617b17e20..f0bfb7e50a 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -75,27 +75,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
/// of kind , indicating the operation type of the node, and if that node has inputs in the graph,
/// it contains metadata of kind , indicating the names of the input nodes.
///
- /// An .
- /// The name of the file containing the TensorFlow model. Currently only frozen model
- /// format is supported.
- public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
+ /// The environment to use.
+ /// Model to load.
+ public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
{
- var bytes = File.ReadAllBytes(modelFile);
- var session = LoadTFSession(ectx, bytes, modelFile);
- return GetModelSchema(ectx, session.Graph);
+ var model = LoadTensorFlowModel(env, modelPath);
+ return GetModelSchema(env, model.Session.Graph);
}
///
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
- /// iterates over the columns of the returned by ,
+ /// iterates over the columns of the returned by ,
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
///
- ///
+ /// Model to load.
///
- public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
+ public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath)
{
- var schema = GetModelSchema(null, modelFile);
+ var schema = GetModelSchema(new MLContext(), modelPath);
for (int i = 0; i < schema.Count; i++)
{
@@ -310,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
}
}
+ ///
+ /// Load TensorFlow model into memory.
+ ///
+ /// The environment to use.
+ /// The model to load.
+ ///
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
{
var session = GetSession(env, modelPath);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 8bfae482c2..4699131680 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -606,8 +606,9 @@ public void TensorFlowTransformCifar()
public void TensorFlowTransformCifarSavedModel()
{
var modelLocation = "cifar_saved_model";
-
var mlContext = new MLContext(seed: 1, conc: 1);
+ var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation);
+ Assert.Equal(335, loadModelSchema.Count);
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
var schema = tensorFlowModel.GetInputSchema();
Assert.True(schema.TryGetColumnIndex("Input", out int column));