Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConvNeXt models #16421

Merged
merged 22 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ce8d99b
feat: initial implementation of convnext.
sayakpaul Apr 15, 2022
2470a0e
chore: added config and cleaned up some code.
sayakpaul Apr 16, 2022
4d49ca2
fix: initial convnext implementation.
sayakpaul Apr 16, 2022
bfd1af8
chore: added doc instantiator.
sayakpaul Apr 17, 2022
9844e27
chore: applied initial PR feedback.
sayakpaul Apr 17, 2022
3b990ae
chore: Block -> ConvNeXtBlock.
sayakpaul Apr 18, 2022
2c1b1e4
chore: corrected repo link and indentation.
sayakpaul Apr 18, 2022
dfe6a6b
feat: added config to convnext block, simplied staging.
sayakpaul Apr 19, 2022
8a150fe
chore: address luke's review feedback.
sayakpaul Apr 30, 2022
e5171b4
chore: add convnext models to init.
sayakpaul May 2, 2022
6b9f091
add: tests, _init_ changes.
sayakpaul May 2, 2022
eaba0d1
chore: self.activation.
sayakpaul May 3, 2022
65de069
changes to build files.
sayakpaul May 6, 2022
78e7558
add: missing comma.
sayakpaul May 6, 2022
e8ee816
feat: application of normalization for 3-channel inputs.
sayakpaul May 6, 2022
d3ee960
fix: reconstruction of convnext models from config.
sayakpaul May 6, 2022
e691392
feat: convnext with functional api.
sayakpaul May 6, 2022
dd84bc3
Merge pull request #2 from sayakpaul/feat/convnext-functional
sayakpaul May 6, 2022
4b24dc0
chore: spacing fix/
sayakpaul May 6, 2022
ea5f829
Merge pull request #3 from sayakpaul/feat/convnext-functional
sayakpaul May 6, 2022
aa7b4c3
chore: datatype to float.
sayakpaul May 9, 2022
c9e5b0d
chore: change weight path.
sayakpaul May 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions keras/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ exports_files(
keras_packages = [
"keras",
"keras.activations",
"keras.applications.convnext",
"keras.applications.densenet",
"keras.applications.efficientnet",
"keras.applications.efficientnet_v2",
Expand Down Expand Up @@ -181,7 +182,7 @@ gen_api_init_files(
"//keras",
"//:expect_tensorflow_installed",
],
packages = keras_packages,
packages=keras_packages,
)

gen_api_init_files(
Expand All @@ -195,5 +196,5 @@ gen_api_init_files(
"//keras",
"//:expect_tensorflow_installed",
],
packages = keras_packages,
packages=keras_packages,
)
2 changes: 2 additions & 0 deletions keras/api/api_init_files.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ KERAS_API_INIT_FILES = [
"keras/__internal__/utils/__init__.py",
"keras/activations/__init__.py",
"keras/applications/__init__.py",
"keras/applications/convnext/__init__.py",
"keras/applications/densenet/__init__.py",
"keras/applications/efficientnet/__init__.py",
"keras/applications/efficientnet_v2/__init__.py",
Expand Down Expand Up @@ -85,6 +86,7 @@ KERAS_API_INIT_FILES_V1 = [
"keras/__internal__/legacy/rnn_cell/__init__.py",
"keras/activations/__init__.py",
"keras/applications/__init__.py",
"keras/applications/convnext/__init__.py",
"keras/applications/densenet/__init__.py",
"keras/applications/efficientnet/__init__.py",
"keras/applications/efficientnet_v2/__init__.py",
Expand Down
19 changes: 19 additions & 0 deletions keras/applications/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
srcs = [
"__init__.py",
"densenet.py",
"convnext.py",
"efficientnet.py",
"efficientnet_v2.py",
"imagenet_utils.py",
Expand Down Expand Up @@ -270,6 +271,24 @@ tf_py_test(
],
)

tf_py_test(
name = "applications_load_weight_test_convnext",
size = "large",
srcs = ["applications_load_weight_test.py"],
args = ["--module=convnext"],
main = "applications_load_weight_test.py",
tags = [
"no_oss",
"no_pip",
],
deps = [
":applications",
"//:expect_absl_installed",
"//:expect_tensorflow_installed",
"//keras/preprocessing",
],
)

tf_py_test(
name = "applications_load_weight_test_densenet",
size = "large",
Expand Down
6 changes: 6 additions & 0 deletions keras/applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
"""Keras Applications are premade architectures with pre-trained weights."""
# pylint: disable=g-bad-import-order

from keras.applications.convnext import ConvNeXtTiny
from keras.applications.convnext import ConvNeXtSmall
from keras.applications.convnext import ConvNeXtBase
from keras.applications.convnext import ConvNeXtLarge
from keras.applications.convnext import ConvNeXtXLarge

from keras.applications.densenet import DenseNet121
from keras.applications.densenet import DenseNet169
from keras.applications.densenet import DenseNet201
Expand Down
5 changes: 5 additions & 0 deletions keras/applications/applications_load_weight_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl.testing import parameterized
import numpy as np

from keras.applications import convnext
from keras.applications import densenet
from keras.applications import efficientnet
from keras.applications import efficientnet_v2
Expand Down Expand Up @@ -55,6 +56,10 @@
'mobilenet_v2': (mobilenet_v2, [mobilenet_v2.MobileNetV2]),
'mobilenet_v3_small': (mobilenet_v3, [mobilenet_v3.MobileNetV3Small]),
'mobilenet_v3_large': (mobilenet_v3, [mobilenet_v3.MobileNetV3Large]),
'convnext':
(convnext,
[convnext.ConvNeXtTiny, convnext.ConvNeXtSmall, convnext.ConvNeXtBase,
convnext.ConvNeXtLarge, convnext.ConvNeXtXLarge]),
'densenet':
(densenet,
[densenet.DenseNet121, densenet.DenseNet169, densenet.DenseNet201]),
Expand Down
14 changes: 13 additions & 1 deletion keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import parameterized
from keras import backend
from keras.applications import convnext
from keras.applications import densenet
from keras.applications import efficientnet
from keras.applications import efficientnet_v2
Expand All @@ -32,6 +33,7 @@
from keras.applications import vgg16
from keras.applications import vgg19
from keras.applications import xception
from keras import utils
import tensorflow.compat.v2 as tf

MODEL_LIST_NO_NASNET = [(resnet.ResNet50, 2048), (resnet.ResNet101, 2048),
Expand All @@ -45,6 +47,11 @@
(mobilenet_v2.MobileNetV2, 1280),
(mobilenet_v3.MobileNetV3Small, 576),
(mobilenet_v3.MobileNetV3Large, 960),
(convnext.ConvNeXtTiny, 768),
(convnext.ConvNeXtSmall, 768),
(convnext.ConvNeXtBase, 1024),
(convnext.ConvNeXtLarge, 1536),
(convnext.ConvNeXtXLarge, 2048),
(densenet.DenseNet121, 1024),
(densenet.DenseNet169, 1664),
(densenet.DenseNet201, 1920),
Expand Down Expand Up @@ -124,7 +131,12 @@ def test_application_base(self, app, _):
model = app(weights=None)
# Can be serialized and deserialized
config = model.get_config()
reconstructed_model = model.__class__.from_config(config)
if "ConvNeXt" in app.__name__:
custom_objects = {"LayerScale": convnext.LayerScale}
with utils.custom_object_scope(custom_objects):
reconstructed_model = model.__class__.from_config(config)
else:
reconstructed_model = model.__class__.from_config(config)
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
backend.clear_session()

Expand Down
Loading