Added Summary function for model summary #1015
@@ -200,3 +200,76 @@ be used for aesthetic purposes, or by recovering Python users.

```julia
macro jit(ex)
  esc(ex)
end

############################################################################
#                           Summary of the model                           #
############################################################################
struct Trainable_layer
  layer_name::Any
  num_params::Int64
end

struct Summary
  layers::Array{Trainable_layer}
  total_params::Int64
end

"""
    print_summary(summ)

Helper function for printing the summary of the model in a proper format.
"""
function print_summary(summ)
  print("+++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++\n")
  for i in 1:length(summ.layers)
    layer = summ.layers[i]
    print("Trainable Layer_$i: ", layer.layer_name, "  ")
    print("Num_Trainable_Parameters: ", layer.num_params, "\n\n")
  end
  print("Total Number of Trainable Parameters: ", summ.total_params, "\n")
end
```
Reviewer: If we do end up with a struct for the summary, it's better to instead overload the `Base.show` method.

Reviewer: Although what advantage does a struct have here?

Author: While discussing this on Slack, someone suggested it would be better to hand the user the information collected by the function and let them decide what to do with it, for example analysing the number of parameters in a fixed subset of layers. That was the sole reason to use a struct here.
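For illustration, a minimal sketch of what the reviewer's suggestion could look like, assuming the `Summary` and `Trainable_layer` structs above (this is not part of the PR): overloading the three-argument `Base.show` makes the REPL render the struct itself, so no separate `print_summary` helper is needed.

```julia
# Sketch only: rich REPL display for Summary via Base.show.
function Base.show(io::IO, ::MIME"text/plain", summ::Summary)
  println(io, "Model Summary")
  for (i, layer) in enumerate(summ.layers)
    println(io, "  Trainable Layer_$i: ", layer.layer_name,
            "  Num_Trainable_Parameters: ", layer.num_params)
  end
  print(io, "Total Number of Trainable Parameters: ", summ.total_params)
end
```

This would also preserve the author's rationale above: because the data lives in a struct, a caller can still compute, say, `sum(l.num_params for l in summ.layers[1:2])` to count the parameters of just the first two layers.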
````julia
"""
    Summary(model)

Call `Summary(model)` to get a brief summary of the model. It returns a struct with two fields: `layers` and `total_params`.

Fields:

1. `Summary(model).layers`: a list of the trainable layers in the model, along with the number of parameters in each layer. It is an array of `Trainable_layer` structs, each with the fields `layer_name` and `num_params`.
2. `Summary(model).total_params`: the total number of trainable parameters used in the model.

```julia
# Example
julia> model = Chain(Dense(10, 64), Dense(64, 1))
Chain(Dense(10, 64), Dense(64, 1))

julia> Flux.Summary(model);
+++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++
Trainable Layer_1: Dense(10, 64)  Num_Trainable_Parameters: 704

Trainable Layer_2: Dense(64, 1)  Num_Trainable_Parameters: 65

Total Number of Trainable Parameters: 769
```
"""
function Summary(model)
  layers_vec = []
  layers = model.layers
  total_trainable_par = 0
  for layer in layers
    try
      x = sum(length, params(layer))
      push!(layers_vec, Trainable_layer(layer, x))
      total_trainable_par += x
    catch error
      # `sum` throws for layers with no trainable parameters
      # (e.g. activation functions); such layers are skipped.
    end
  end
  summ = Summary(layers_vec, total_trainable_par)
  print_summary(summ)
  return summ
end
#######################################################################################
````
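As a sanity check on the example output above (not stated in the PR): a `Dense(in, out)` layer stores an `out × in` weight matrix plus an `out`-element bias vector, so `Dense(10, 64)` has 10 × 64 + 64 = 704 parameters and `Dense(64, 1)` has 64 × 1 + 1 = 65, giving the 769 total shown.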
Reviewer: Probably don't need to maintain a copy of the existing information (`layers = model.layers`). Layers with no params can probably be ignored, although a placeholder suggesting the existence of one would be good.
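A hedged sketch of how the loop might address this comment, iterating `model.layers` directly and recording an explicit zero-parameter placeholder entry for layers without trainable parameters (illustrative only, not part of the PR):

```julia
# Sketch only: no intermediate copy of model.layers, and layers without
# trainable parameters (e.g. activation functions) get a zero-parameter
# placeholder entry instead of being silently skipped.
function Summary(model)
  layers_vec = Trainable_layer[]
  total_trainable_par = 0
  for layer in model.layers
    ps = params(layer)
    n = isempty(ps) ? 0 : sum(length, ps)  # 0 acts as the placeholder count
    push!(layers_vec, Trainable_layer(layer, n))
    total_trainable_par += n
  end
  summ = Summary(layers_vec, total_trainable_par)
  print_summary(summ)
  return summ
end
```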
Reviewer: Please follow the same style as the rest of the code.

Author: I apologize for that; this was my first PR to Flux, so I wasn't yet accustomed to the style. I will keep that in mind in the future.
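For illustration, the structs might look like this under common Julia naming conventions (CamelCase type names, no underscores); the names here are hypothetical, not taken from Flux:

```julia
# Sketch only: the same data under conventional Julia naming.
struct TrainableLayer
  layer::Any    # the layer object itself, e.g. Dense(10, 64)
  nparams::Int  # number of trainable parameters in the layer
end

struct ModelSummary
  layers::Vector{TrainableLayer}
  total_params::Int
end
```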