diff --git a/src/utils.jl b/src/utils.jl index 2dba21c740..ec8ff06b8f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -200,3 +200,76 @@ be used for aesthetic purposes, or by recovering Python users. macro jit(ex) esc(ex) end + +############################################################################ + # Summary of the model # +############################################################################ +struct Trainable_layer + layer_name::Any + num_params::Int64 +end + +struct Summary + layers::Array{Trainable_layer} + total_params::Int64 +end + +""" + print_summary(Summary) +Helper function for printing the summary of the model in a proper format. +""" +function print_summary(summ) + + print("+++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++\n") + for i in 1:length(summ.layers) + layer = summ.layers[i] + print("") + print("Trainable Layer_$i: ",layer.layer_name," ") + print("Num_Trainable_Parameters: ",layer.num_params,"\n\n") + end + print("Total Number of Trainable Parameters: ",summ.total_params,"\n") +end + + +"""" + Summary(model) +Call `Summary(model)` to get the brief summary of the model. It return struct of two fields: layers and total_params. + +Fields/Parameters: + +1.`Summary(model).layers`: gives a list of trainable layers in the model, along with the number of parameters in each layer. + It returns an array of struct `Trainable_layers`, which have fields `layer_name` and `num_params`.
+
+2.`Summary(model).total_params`: returns the total number of trainable parameters used in the model + +```julia +#Example +julia> model = Chain(Dense(10,64), Dense(64,1)) +Chain(Dense(10, 64), Dense(64, 1)) + +julia> Flux.Summary(model); ++++++++++++++++++++++++++++++Model Summary++++++++++++++++++++++++++++++++++++ +Trainable Layer_1: Dense(10, 64) Num_Trainable_Parameters: 704 + +Trainable Layer_2: Dense(64, 1) Num_Trainable_Parameters: 65 + +Total Number of Trainable Parameters: 769 +``` +""" +function Summary(model) + layers_vec =[] + layers = model.layers + Total_trainable_par =0 + for layer in layers + try + x = sum(length,params(layer)) + push!(layers_vec,Trainable_layer(layer,x)) + Total_trainable_par+=x + catch error + end + end + Summ = Summary(layers_vec,Total_trainable_par) + print_summary(Summ) + return Summ +end +#######################################################################################