Weights densenet #1855
Conversation
```python
x = keras.layers.Rescaling(scale=1 / 255.0)(image_input)
x = keras.layers.Normalization(
    axis=channel_axis,
    mean=(0.485, 0.456, 0.406),
```
Why do we need a normalization layer in addition to rescaling, and where are these constants coming from?
These are the standard ImageNet mean and std, which are also listed in the timm config: https://huggingface.co/timm/densenet201.tv_in1k/blob/main/config.json.
In the Torch model the normalization is not applied inside the model itself, so we need to apply it ourselves along the channel axis.
Without the normalization, the difference against the timm model was about 0.6 for the 201 variant; after applying normalization, it dropped to about 0.3.
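For reference, here is a minimal sketch of the full rescale-then-normalize pipeline; the input shape, `image_input`, and `channel_axis` values are assumptions for illustration, and the mean/std are the standard ImageNet constants mentioned above:

```python
import keras

# Assumed channels-last 224x224 input for this sketch; the real backbone
# defines `image_input` and `channel_axis` itself.
image_input = keras.Input(shape=(224, 224, 3))
channel_axis = -1

# Rescale to [0, 1], then normalize with the ImageNet channel statistics.
x = keras.layers.Rescaling(scale=1 / 255.0)(image_input)
x = keras.layers.Normalization(
    axis=channel_axis,
    mean=(0.485, 0.456, 0.406),
    # keras.layers.Normalization takes a variance, i.e. the squared
    # ImageNet std (0.229, 0.224, 0.225).
    variance=(0.229**2, 0.224**2, 0.225**2),
)(x)
```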
Here is how timm applies the config: `transforms = timm.data.create_transform(**data_config, is_training=False)`.
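For context, a short sketch of how that transform gets built from the pretrained data config (the model tag matches the HF config linked above):

```python
import timm

# Load the pretrained model, pull its data config, and build the eval transform.
model = timm.create_model("densenet201.tv_in1k", pretrained=True)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

print(data_config)  # mean/std, interpolation, crop_pct, input_size, etc.
```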
Interesting. Let me chat with @fchollet about this, but it makes me think we might want to push rescaling and this normalization into our preprocessing layers and remove `include_rescaling` everywhere. That seems conceptually cleaner, and it would give us a good way to propagate these values as config instead of hardcoding them.
Okay, we may also have to check the
```
"interpolation": "bicubic",
"crop_pct": 0.95,
```
settings, which timm uses in many of its models, so that our results come closer to the timm benchmark numbers.
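Roughly, on the Keras side that would look something like the sketch below; the 224 target size is taken from the config, and this is just an illustration, not part of this PR:

```python
import keras

# timm's crop_pct=0.95 means: resize so the short side is 224 / 0.95 ≈ 236
# (bicubic), then take a 224x224 center crop. A rough Keras approximation
# (this resizes both dims to 236, so it does not preserve aspect ratio exactly):
resize = keras.layers.Resizing(236, 236, interpolation="bicubic")
crop = keras.layers.CenterCrop(224, 224)
```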
I don't think we need to match the crop percentage. Better to do a simple rescaling of the user's image than to match their crop. Interpolation matching might be worth doing, though?
This is awesome! The only thing I don't get is the rescaling. Left a question below.
Also, looks like some legit test failures to look into.
#1859 is in, so this just needs to rebase over those changes, reupload the Kaggle presets, and fix these failing tests.
Uploaded the latest weights and configurations to the DenseNet Kaggle page: https://www.kaggle.com/models/kerashub/densenet
Great work!! LGTM!
This reverts commit f67b4db.
Tested Colab: