-
Notifications
You must be signed in to change notification settings - Fork 469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: burn-import unary operators #548
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! Thank you!
Just two comments
burn-import/src/burn/node/base.rs
Outdated
Node::Linear(_) => "linear".to_string(), | ||
Node::BatchNorm(_) => "batch_norm".to_string(), | ||
Node::Equal(_) => "equal".to_string(), | ||
Node::Unary(unary) => unary.name.to_string(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of String
, can we define an enum
for all unary
operators? In this specific case we would have a double match
, the second one after Node::unary(node)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is only used in logging, so the important part is having something clean, we already have an enum, but the Default
formating isn't pretty enough for printing. Maybe this should be refactoring into a Display
function instead of a "name".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that could be an idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a strum_macros::Display
derive that you can use. You can also specify serialization types, see https://docs.rs/strum/latest/strum/additional_attributes/index.html. We will basically use this:
#[derive(Debug, strum_macros::Display)]
#[strum(serialize_all = "snake_case")]
We are using strum in burn-import.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now we can use this approach since it is already there, but I think we should try to use the less amount of dependencies as possible in my opinion. (when it's possible obviously)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree but we are already using this dependency in burn-import.
burn-import/src/burn/node/unary.rs
Outdated
pub struct UnaryNode { | ||
pub input: TensorType, | ||
pub output: TensorType, | ||
pub name: String, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we define an enum instead of using a String?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like a good idea.
We can also add tests that have been lost with this PR |
Change how the unary operations are coded in
burn-import
to reduce code duplication.