Skip to content

Commit

Permalink
Consistent saved_model output format (ultralytics#7032)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrinalJain17 authored Mar 18, 2022
1 parent 4b81e09 commit 0d5433a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def export_saved_model(model, im, file, dynamic,
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
tfm.__call__ = tf.function(lambda x: frozen_func(x)[0], [spec])
tfm.__call__(im)
tf.saved_model.save(
tfm,
Expand Down
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
else: # Lite or Edge TPU
Expand Down

0 comments on commit 0d5433a

Please sign in to comment.