
Added Summary function for model summary #1015

Closed · wants to merge 2 commits
73 changes: 73 additions & 0 deletions src/utils.jl
@@ -200,3 +200,76 @@ be used for aesthetic purposes, or by recovering Python users.
macro jit(ex)
esc(ex)
end

############################################################################
# Summary of the model #
############################################################################
struct Trainable_layer
    layer_name::Any
Member: Please follow the same style as the rest of the code.

Contributor Author: I apologize for that; this was my first PR to Flux, so I wasn't yet accustomed to the code style. I will keep it in mind in the future.

    num_params::Int64
end

struct Summary
    layers::Vector{Trainable_layer}
    total_params::Int64
end
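Since `Summary` is a plain struct, the collected information can be consumed programmatically rather than only printed. A hedged usage sketch, assuming the `Summary(model)` constructor defined further down in this diff:

```julia
# Hypothetical usage: keep only the layers above some parameter budget.
summ = Flux.Summary(model)
big_layers = filter(l -> l.num_params > 100, summ.layers)
```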

"""
print_summary(Summary)
Helper function for printing the summary of the model in a proper format.
"""
function print_summary(summ)

Member: If we do end up with a struct for the summary, it's better to instead overload the `show` method here. (A sketch of this follows the function below.)

Member: Although, what advantage does a struct have here?

Contributor Author: While discussing it on Slack, someone suggested it would be better to hand the user the information collected by the function and let them decide what to do with it, e.g. analyzing the number of parameters in some fixed subset of layers. That was the sole reason for using a struct here.

print("+++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++\n")
for i in 1:length(summ.layers)
layer = summ.layers[i]
print("")
print("Trainable Layer_$i: ",layer.layer_name," ")
print("Num_Trainable_Parameters: ",layer.num_params,"\n\n")
end
print("Total Number of Trainable Parameters: ",summ.total_params,"\n")
end
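Following the reviewer's suggestion above, a `Base.show` overload could replace this helper, so that a `Summary` prints itself at the REPL. A minimal sketch against the structs defined above (not part of this PR):

```julia
# Hypothetical sketch: make `Summary` self-printing instead of using print_summary.
function Base.show(io::IO, summ::Summary)
    println(io, "Model Summary")
    for (i, layer) in enumerate(summ.layers)
        println(io, "Trainable Layer_$i: ", layer.layer_name,
                "    Num_Trainable_Parameters: ", layer.num_params)
    end
    print(io, "Total Number of Trainable Parameters: ", summ.total_params)
end
```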


""""
Summary(model)
Call `Summary(model)` to get the brief summary of the model. It return struct of two fields: layers and total_params.

Fields/Parameters:

1.`Summary(model).layers`: gives a list of trainable layers in the model, along with the number of parameters in each layer.
It returns an array of struct `Trainable_layers`, which have fields `layer_name` and `num_params`.<br>
<br>
2.`Summary(model).total_params`: returns the total number of trainable parameters used in the model

```julia
# Example
julia> model = Chain(Dense(10,64), Dense(64,1))
Chain(Dense(10, 64), Dense(64, 1))

julia> Flux.Summary(model);
+++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++
Trainable Layer_1: Dense(10, 64) Num_Trainable_Parameters: 704

Trainable Layer_2: Dense(64, 1) Num_Trainable_Parameters: 65

Total Number of Trainable Parameters: 769
```
"""
function Summary(model)
    layers_vec = Trainable_layer[]
    layers = model.layers
Member: Probably don't need to maintain a copy of existing information. Layers with no params can probably be ignored, although a placeholder suggesting the existence of one would be good. (A sketch along these lines follows the function below.)

    total_trainable_params = 0
    for layer in layers
        ps = params(layer)
        isempty(ps) && continue  # skip layers with no trainable parameters
        n = sum(length, ps)
        push!(layers_vec, Trainable_layer(layer, n))
        total_trainable_params += n
    end
    summ = Summary(layers_vec, total_trainable_params)
    print_summary(summ)
    return summ
end
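For reference, a hedged sketch of the direction the review comment above points at: keep zero-parameter layers in the summary as placeholders, and avoid per-layer special-casing entirely. The name `summarize` is hypothetical and not part of this PR:

```julia
# Hypothetical variant: every layer appears in the summary; layers without
# trainable parameters show up as placeholders with num_params == 0.
function summarize(model)
    layers = [Trainable_layer(layer, sum(length, params(layer); init = 0))
              for layer in model.layers]
    return Summary(layers, sum(l -> l.num_params, layers; init = 0))
end
```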
#######################################################################################