Support custom kwargs for model card in save_pretrained #2310

Merged: 5 commits into huggingface:main on Jun 4, 2024


qubvel
Member

@qubvel qubvel commented Jun 3, 2024

Hi!

The mixins already let you customize the model card template. It would be great to also support passing custom kwargs to the card.

Here is an example:

import torch
from huggingface_hub import PyTorchModelHubMixin


MODEL_CARD_TEMPLATE = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
{{ custom_data }}
"""

class Model(torch.nn.Module, PyTorchModelHubMixin, model_card_template=MODEL_CARD_TEMPLATE):
    pass
    
model = Model()
model_card_kwargs = {"custom_data": "This is an awesome model..."}
model.save_pretrained("test-model", model_card_kwargs=model_card_kwargs, push_to_hub=True)

Generated readme:

---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---
This is an awesome model...

Let me know what you think about this feature! If it's OK, I can add the necessary tests.
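
To make the mechanism concrete, here is a minimal sketch of how extra keyword arguments can fill placeholders in a card template. It is a simplified stand-in that assumes plain `{{ name }}` placeholders and uses only the Python standard library; the mixin itself builds the card through `ModelCard`. `render_card` is a hypothetical helper, not part of `huggingface_hub`.

```python
import re

TEMPLATE = """---
{{ card_data }}
---
{{ custom_data }}
"""


def render_card(template: str, **kwargs: str) -> str:
    """Replace each `{{ name }}` placeholder with the matching keyword argument.

    Hypothetical helper for illustration; unknown placeholders render as "".
    """

    def substitute(match):
        # group(1) is the placeholder name without braces or whitespace.
        return str(kwargs.get(match.group(1), ""))

    return re.sub(r"\{\{\s*(\w+)\s*\}\}", substitute, template)


readme = render_card(
    TEMPLATE,
    card_data="tags:\n- pytorch_model_hub_mixin\n- model_hub_mixin",
    custom_data="This is an awesome model...",
)
print(readme)
```

The printed text has the same shape as the generated readme above: the metadata block between the `---` markers, followed by the custom text.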

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Wauplin
Contributor

Wauplin commented Jun 3, 2024

The mixin currently has a set of very relevant model card metadata. For more customization, it's possible to override the generate_model_card method. This approach gives more flexibility and access to the model's values/config. It returns a ModelCard object. For example:

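from huggingface_hub import ModelCard, PyTorchModelHubMixin
from torch import nn


class Model(nn.Module, PyTorchModelHubMixin):
    ...

    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        # Build the default card first, then add custom metadata to it.
        card = super().generate_model_card(*args, **kwargs)
        card.data.custom_key = "custom value"
        return card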

I'd prefer not to add too many features to the main way of defining tags/library_name/... to avoid complexity and promote this flexible method instead. WDYT?

@qubvel
Member Author

qubvel commented Jun 3, 2024

The feature is inspired by this PR. As you can see, I am trying to make it possible to pass a custom metrics dictionary and dataset name in order to generate cards like this one:
https://huggingface.co/qubvel-hf/oxford-pet-segmentation

model.save_pretrained('./my_model', metrics={'accuracy': 0.95}, dataset='my_dataset')

You might want to pass arguments not when the model class is created but when the model is saved, for example after the model has been trained. I overrode the generate_model_card() method, but I didn't find a way to pass any parameters to it from the save_pretrained() method. The only way I found is to set them as private attributes in save_pretrained and then remove them.
https://github.com/qubvel/segmentation_models.pytorch/blob/3d6da1d74636873372c265f300862a6a6d01777d/segmentation_models_pytorch/base/hub_mixin.py#L107
Now it looks like a dirty hack :(
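
The idea described above can be sketched with a simplified stand-in for the mixin. `ToyHubMixin`, `SegmentationModel`, and the README layout are hypothetical and use only the Python standard library; this is not the `huggingface_hub` implementation. The sketch shows `save_pretrained` forwarding `model_card_kwargs` to `generate_model_card`, so save-time values such as metrics reach the card without being stored on the model as private attributes.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


class ToyHubMixin:
    """Simplified stand-in for a hub mixin (not the huggingface_hub class)."""

    def generate_model_card(self, **kwargs: Any) -> str:
        # Default card: metadata block only.
        return "---\ntags:\n- model_hub_mixin\n---\n"

    def save_pretrained(
        self, save_directory: str, model_card_kwargs: Optional[Dict[str, Any]] = None
    ) -> Path:
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)
        # Forward save-time values straight to the card generator.
        card = self.generate_model_card(**(model_card_kwargs or {}))
        readme = path / "README.md"
        readme.write_text(card, encoding="utf-8")
        return readme


class SegmentationModel(ToyHubMixin):
    def generate_model_card(self, metrics=None, dataset=None, **kwargs):
        # Start from the default card, then append the save-time values.
        card = super().generate_model_card(**kwargs)
        if dataset is not None:
            card += f"Dataset: {dataset}\n"
        for name, value in (metrics or {}).items():
            card += f"{name}: {value}\n"
        return card


with tempfile.TemporaryDirectory() as tmp:
    readme_path = SegmentationModel().save_pretrained(
        tmp, model_card_kwargs={"metrics": {"accuracy": 0.95}, "dataset": "my_dataset"}
    )
    card_text = readme_path.read_text(encoding="utf-8")

print(card_text)
```

Because the values travel as arguments, the model instance keeps no temporary state between the call to `save_pretrained` and the card being written.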

Contributor

@Wauplin Wauplin left a comment

Thanks for the clarification and sorry for the confusion on my side. You are totally right, this makes sense to implement. I just reviewed your code and the integration looks good 👍

Could you add a small test in test_hub_mixin_pytorch.py with a new class that overrides generate_model_card, passes data in save_pretrained, and checks that you get the expected result? Thanks in advance!

(Small nit about the integration you've mentioned: it would be good to call super().generate_model_card() in the integration and define the model card template/library_name/tags/etc. in the class inheritance. This way, models will be tagged as model_hub_mixin and pytorch_model_hub_mixin on the Hub, which is important for us to track integrations and to make sure we don't break anything if we update this mixin in the future.)

@qubvel
Member Author

qubvel commented Jun 4, 2024

@Wauplin thank you for the feedback! I will update the integration. It would be great to include this PR's changes there to get rid of the private attributes :)

> Could you add a small test in test_hub_mixin_pytorch.py with a new class that overrides generate_model_card, passes data in save_pretrained, and checks that you get the expected result? Thanks in advance!

Done! Please, have a look!

Contributor

@Wauplin Wauplin left a comment

Amazing! Thanks for the prompt update :) Test is clean, CI is green, let's merge 🎉

@Wauplin Wauplin merged commit 54515da into huggingface:main Jun 4, 2024
14 checks passed
@qubvel
Member Author

qubvel commented Jun 4, 2024

Thanks for the quick review and merge 🤗

Wauplin pushed a commit that referenced this pull request Jun 14, 2024
* Support custom kwargs for model card in save_pretrained

* Fix failing test

* Fix test for pytorch mixin

* Add test for model_card_kwargs

* Fix style