Skip to content

Commit

Permalink
all trainable / loadable
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 9, 2024
1 parent 59f1c30 commit 6ea0b98
Show file tree
Hide file tree
Showing 28 changed files with 41 additions and 118 deletions.
15 changes: 0 additions & 15 deletions doctr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@
logging.info("Disabling PyTorch because USE_TF is set")
_torch_available = False

# Compatibility fix to make sure tensorflow.keras stays at Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1"

elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
raise ValueError(
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
)


def ensure_keras_v2() -> None: # pragma: no cover
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
os.environ["TF_USE_LEGACY_KERAS"] = "1"


if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
Expand Down Expand Up @@ -79,7 +65,6 @@ def ensure_keras_v2() -> None: # pragma: no cover
_tf_available = False
else:
logging.info(f"TensorFlow version {_tf_version} available.")
ensure_keras_v2()
import tensorflow as tf

# Enable eager execution - this is required for some models to work properly
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
"url": None,
},
}

Expand Down
12 changes: 6 additions & 6 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,42 +32,42 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
"url": None,
},
"mobilenet_v3_large_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
"url": None,
},
"mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
"url": None,
},
"mobilenet_v3_small_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
"url": None,
},
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
"url": None,
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
"url": None,
},
}

Expand Down
10 changes: 5 additions & 5 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,35 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
"url": None,
},
"resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
"url": None,
},
"resnet34": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
"url": None,
},
"resnet50": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
"url": None,
},
"resnet34_wide": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
"url": None,
},
}

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
"url": None,
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
"url": None,
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
"url": None,
},
}

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"std": (1.0, 1.0, 1.0),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
"url": None,
},
}

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/classification/vit/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
"url": None,
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
"url": None,
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
"url": None,
},
"db_mobilenet_v3_large": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
"url": None,
},
}

Expand Down
9 changes: 3 additions & 6 deletions doctr/models/detection/fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
"url": None,
},
"fast_small": {
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
"url": None,
},
"fast_base": {
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
"url": None,
},
}

Expand Down Expand Up @@ -342,9 +342,6 @@ def _fast(
skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
)

# Build the model for reparameterization to access the layers
_ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)

return model


Expand Down
6 changes: 3 additions & 3 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
"url": None,
},
"linknet_resnet34": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
"url": None,
},
"linknet_resnet50": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
"url": None,
},
}

Expand Down
6 changes: 1 addition & 5 deletions doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if is_torch_available():
import torch
elif is_tf_available():
import tensorflow as tf
pass

__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

Expand Down Expand Up @@ -76,8 +76,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
torch.save(model.state_dict(), weights_path)
elif is_tf_available():
weights_path = save_directory / "tf_model.weights.h5"
# NOTE: `model.build` is not an option because it doesn't run in eager mode
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
model.save_weights(str(weights_path))

config_path = save_directory / "config.json"
Expand Down Expand Up @@ -229,8 +227,6 @@ def from_hub(repo_id: str, **kwargs: Any):
model.load_state_dict(state_dict)
else: # tf
weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
# NOTE: `model.build` is not an option because it doesn't run in eager mode
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
model.load_weights(weights)

return model
8 changes: 4 additions & 4 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["legacy_french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
"vocab": VOCABS["french"],
"url": None,
},
"crnn_mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
"url": None,
},
"crnn_mobilenet_v3_large": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
"url": None,
},
}

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
"url": None,
},
}

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
"url": None,
},
}

Expand Down
6 changes: 2 additions & 4 deletions doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
"url": None,
},
}

Expand Down Expand Up @@ -170,9 +170,7 @@ def call(
for t in range(self.max_length + 1): # 32
if t == 0:
# step to init the first states of the LSTMCell
states = self.lstm_cells.get_initial_state(
inputs=None, batch_size=features.shape[0], dtype=features.dtype
)
states = self.lstm_cells.get_initial_state(batch_size=features.shape[0])
prev_symbol = holistic
elif t == 1:
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
"url": None,
},
"vitstr_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 128, 3),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
"url": None,
},
}

Expand Down
6 changes: 1 addition & 5 deletions doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ def load_pretrained_params(
else:
archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)

# Build the model
# NOTE: `model.build` is not an option because it doesn't run in eager mode
_ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)

# Load weights
model.load_weights(archive_path, skip_mismatch=skip_mismatch)

Expand Down Expand Up @@ -125,7 +121,7 @@ class IntermediateLayerGetter(Model):
"""

def __init__(self, model: Model, layer_names: List[str]) -> None:
intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
intermediate_fmaps = [model.get_layer(layer_name)._inbound_nodes[0].outputs[0] for layer_name in layer_names]
super().__init__(model.input, outputs=intermediate_fmaps)

def __repr__(self) -> str:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ dependencies = [
tf = [
# cf. https://github.com/mindee/doctr/pull/1461
"tensorflow>=2.15.0,<3.0.0",
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
]
torch = [
Expand Down Expand Up @@ -98,7 +97,6 @@ dev = [
# Tensorflow
# cf. https://github.com/mindee/doctr/pull/1461
"tensorflow>=2.15.0,<3.0.0",
"tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility
"tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0
# PyTorch
"torch>=1.12.0,<3.0.0",
Expand Down
Loading

0 comments on commit 6ea0b98

Please sign in to comment.