Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: burn-import unary operators #548

Merged
merged 4 commits into from
Jul 27, 2023

Conversation

nathanielsimard
Copy link
Member

Change how the unary operations are coded in burn-import to reduce code duplication.

Copy link
Collaborator

@Luni-4 Luni-4 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! Thank you!

Just two comments

Node::Linear(_) => "linear".to_string(),
Node::BatchNorm(_) => "batch_norm".to_string(),
Node::Equal(_) => "equal".to_string(),
Node::Unary(unary) => unary.name.to_string(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of String, can we define an enum for all unary operators? In this specific case we would have a double match, the second one after Node::unary(node)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only used in logging, so the important part is having something clean, we already have an enum, but the Default formating isn't pretty enough for printing. Maybe this should be refactoring into a Display function instead of a "name".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that could be an idea

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nathanielsimard @Luni-4

There is a strum_macros::Display
derive that you can use. You can also specify serialization types, see https://docs.rs/strum/latest/strum/additional_attributes/index.html. We will basically use this:

#[derive(Debug, strum_macros::Display)]
#[strum(serialize_all = "snake_case")]

We are using strum in burn-import.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antimora @nathanielsimard

For now we can use this approach since it is already there, but I think we should try to use the less amount of dependencies as possible in my opinion. (when it's possible obviously)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree but we are already using this dependency in burn-import.

pub struct UnaryNode {
pub input: TensorType,
pub output: TensorType,
pub name: String,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we define an enum instead of using a String?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a good idea.

@Luni-4
Copy link
Collaborator

Luni-4 commented Jul 26, 2023

We can also add tests that have been lost with this PR

@nathanielsimard nathanielsimard merged commit f0a7135 into main Jul 27, 2023
@nathanielsimard nathanielsimard deleted the refactor/import/unary-ops branch July 27, 2023 17:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants