-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tensorflow GetModelSchema should support unfrozen models #2112
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,27 +75,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str | |
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph, | ||
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes. | ||
/// </summary> | ||
/// <param name="ectx">An <see cref="IExceptionContext"/>.</param> | ||
/// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model | ||
/// format is supported.</param> | ||
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile) | ||
/// <param name="env">The environment to use.</param> | ||
/// <param name="modelPath">Model to load.</param> | ||
public static Schema GetModelSchema(IHostEnvironment env, string modelPath) | ||
{ | ||
var bytes = File.ReadAllBytes(modelFile); | ||
var session = LoadTFSession(ectx, bytes, modelFile); | ||
return GetModelSchema(ectx, session.Graph); | ||
var model = LoadTensorFlowModel(env, modelPath); | ||
return GetModelSchema(env, model.Session.Graph); | ||
} | ||
|
||
/// <summary> | ||
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It | ||
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext, string)"/>, | ||
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>, | ||
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names. | ||
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type. | ||
/// </summary> | ||
/// <param name="modelFile"></param> | ||
/// <param name="modelPath">Model to load.</param> | ||
/// <returns></returns> | ||
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) | ||
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath) | ||
{ | ||
var schema = GetModelSchema(null, modelFile); | ||
var schema = GetModelSchema(new MLContext(), modelPath); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not correct. If you need an /cc @TomFinley There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only used as an IExceptionContext, so if we remove the check that env!=null on line 325, we can pass null. In reply to: 247276711 [](ancestors = 247276711) |
||
|
||
for (int i = 0; i < schema.Count; i++) | ||
{ | ||
|
@@ -310,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity) | |
} | ||
} | ||
|
||
/// <summary> | ||
/// Load TensorFlow model into memory. | ||
/// </summary> | ||
/// <param name="env">The environment to use.</param> | ||
/// <param name="modelPath">The model to load.</param> | ||
/// <returns></returns> | ||
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath) | ||
{ | ||
var session = GetSession(env, modelPath); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -606,8 +606,9 @@ public void TensorFlowTransformCifar() | |
public void TensorFlowTransformCifarSavedModel() | ||
{ | ||
var modelLocation = "cifar_saved_model"; | ||
|
||
var mlContext = new MLContext(seed: 1, conc: 1); | ||
var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation); | ||
Assert.Equal(335, loadModelSchema.Count); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
am curious , what does this count signify ? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's all the nodes inside tensor flow graph. In reply to: 246930752 [](ancestors = 246930752) |
||
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation); | ||
var schema = tensorFlowModel.GetInputSchema(); | ||
Assert.True(schema.TryGetColumnIndex("Input", out int column)); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is only used as an IExceptionContext, so instead of changing it here to IHostEnvironment, you can change LoadTensorFlowModel and GetSession to take an IExceptionContext.