diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 1617b17e20..f0bfb7e50a 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -75,27 +75,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str /// of kind , indicating the operation type of the node, and if that node has inputs in the graph, /// it contains metadata of kind , indicating the names of the input nodes. /// - /// An . - /// The name of the file containing the TensorFlow model. Currently only frozen model - /// format is supported. - public static Schema GetModelSchema(IExceptionContext ectx, string modelFile) + /// The environment to use. + /// Model to load. + public static Schema GetModelSchema(IHostEnvironment env, string modelPath) { - var bytes = File.ReadAllBytes(modelFile); - var session = LoadTFSession(ectx, bytes, modelFile); - return GetModelSchema(ectx, session.Graph); + var model = LoadTensorFlowModel(env, modelPath); + return GetModelSchema(env, model.Session.Graph); } /// /// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It - /// iterates over the columns of the returned by , + /// iterates over the columns of the returned by , /// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names. /// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type. /// - /// + /// Model to load. /// - public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) + public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath) { - var schema = GetModelSchema(null, modelFile); + var schema = GetModelSchema(new MLContext(), modelPath); for (int i = 0; i < schema.Count; i++) { @@ -310,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity) } } + /// + /// Load TensorFlow model into memory. + /// + /// The environment to use. + /// The model to load. + /// public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath) { var session = GetSession(env, modelPath); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 8bfae482c2..4699131680 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -606,8 +606,9 @@ public void TensorFlowTransformCifar() public void TensorFlowTransformCifarSavedModel() { var modelLocation = "cifar_saved_model"; - var mlContext = new MLContext(seed: 1, conc: 1); + var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation); + Assert.Equal(335, loadModelSchema.Count); var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation); var schema = tensorFlowModel.GetInputSchema(); Assert.True(schema.TryGetColumnIndex("Input", out int column));