Skip to content

Commit

Permalink
fix: updated alexnet weights
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 11, 2024
1 parent 6b65a95 commit 0d9f03f
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 23 deletions.
6 changes: 3 additions & 3 deletions Artifacts.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[alexnet]
git-tree-sha1 = "8904a6756649aa4cd264328430b829dafde95645"
git-tree-sha1 = "f1739363d54a358cae904133699a93eca7b7a028"
lazy = true

[[alexnet.download]]
sha256 = "e20107404aba1c2c0ed3fad4314033a2fa600cdc0c55d03bc1bfe4f8e5031105"
url = "https://github.com/LuxDL/Lux.jl/releases/download/weights/alexnet.tar.gz"
sha256 = "feb3e1600179ba00b72a68759c7f3b12f400f6d59b28ac72b48614cbafa187d8"
url = "https://huggingface.co/LuxDL/alexnet/resolve/2c48051ecb131d38f2209470cdda70a343289db1/alexnet.tar.gz"

# [resnet101]
# git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896"
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand Down Expand Up @@ -51,6 +52,7 @@ ConcreteStructs = "0.2.3"
DataInterpolations = "< 5.3"
DynamicExpressions = "0.16, 0.17, 0.18, 0.19"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
GPUArraysCore = "0.1.6"
JLD2 = "0.4.48, 0.5"
LazyArtifacts = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/vision.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Native Lux Models

```@docs
Vision.AlexNet
Vision.VGG
Vision.VisionTransformer
```
Expand All @@ -14,7 +15,6 @@ Vision.VisionTransformer
You need to load `Metalhead` before using these models.

```@docs
Vision.AlexNet
Vision.ConvMixer
Vision.DenseNet
Vision.GoogLeNet
Expand Down
5 changes: 0 additions & 5 deletions ext/BoltzMetalheadExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ using Boltz: Boltz, Utils, Vision

Utils.is_extension_loaded(::Val{:Metalhead}) = true

function Vision.AlexNetMetalhead()
model = FromFluxAdaptor()(Metalhead.AlexNet().layers)
return :alexnet, model
end

function Vision.ResNetMetalhead(depth::Int)
@argcheck depth in (18, 34, 50, 101, 152)
model = FromFluxAdaptor()(Metalhead.ResNet(depth).layers)
Expand Down
19 changes: 19 additions & 0 deletions src/initialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module InitializeModels

using ArgCheck: @argcheck
using Artifacts: Artifacts, @artifact_str
using Functors: fmap
using LazyArtifacts: LazyArtifacts
using Random: Random

Expand Down Expand Up @@ -30,4 +31,22 @@ end

function load_using_jld2_internal end

struct SerializedRNG end

function remove_rng_from_structure(x)
return fmap(x) do xᵢ
xᵢ isa Random.AbstractRNG && return SerializedRNG()
return xᵢ
end
end

loadparameters(x) = x

function loadstates(x)
return fmap(x) do xᵢ
xᵢ isa SerializedRNG && return Random.default_rng()
return xᵢ
end
end

end
5 changes: 4 additions & 1 deletion src/vision/Vision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ abstract type AbstractLuxVisionLayer <: AbstractLuxWrapperLayer{:layer} end

for op in (:states, :parameters)
fname = Symbol(:initial, op)
fname_load = Symbol(:load, op)
@eval function LuxCore.$(fname)(rng::AbstractRNG, model::AbstractLuxVisionLayer)
if hasfield(typeof(model), :pretrained) && model.pretrained
path = InitializeModels.get_pretrained_weights_path(model.pretrained_name)
return InitializeModels.load_using_jld2(
jld2_loaded_obj = InitializeModels.load_using_jld2(
joinpath(path, "$(model.pretrained_name).jld2"), $(string(op)))
return InitializeModels.$(fname_load)(jld2_loaded_obj)
end
return LuxCore.$(fname)(rng, model.layer)
end
end

include("extensions.jl")
include("alexnet.jl")
include("vit.jl")
include("vgg.jl")

Expand Down
40 changes: 40 additions & 0 deletions src/vision/alexnet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
AlexNet(; kwargs...)
Create an AlexNet model [krizhevsky2012imagenet](@citep).
## Keyword Arguments
- `pretrained::Bool=false`: If `true`, loads pretrained weights when `LuxCore.setup` is
called.
"""
@concrete struct AlexNet <: AbstractLuxVisionLayer
layer
pretrained_name::Symbol
pretrained::Bool
end

function AlexNet(; pretrained=false)
alexnet = Lux.Chain(;
backbone=Lux.Chain(
Lux.Conv((11, 11), 3 => 64, relu; stride=4, pad=2),
Lux.MaxPool((3, 3); stride=2),
Lux.Conv((5, 5), 64 => 192, relu; pad=2),
Lux.MaxPool((3, 3); stride=2),
Lux.Conv((3, 3), 192 => 384, relu; pad=1),
Lux.Conv((3, 3), 384 => 256, relu; pad=1),
Lux.Conv((3, 3), 256 => 256, relu; pad=1),
Lux.MaxPool((3, 3); stride=2)
),
classifier=Lux.Chain(
Lux.AdaptiveMeanPool((6, 6)),
Lux.FlattenLayer(),
Lux.Dropout(0.5f0),
Lux.Dense(256 * 6 * 6 => 4096, relu),
Lux.Dropout(0.5f0),
Lux.Dense(4096 => 4096, relu),
Lux.Dense(4096 => 1000)
)
)
return AlexNet(alexnet, :alexnet, pretrained)
end
14 changes: 1 addition & 13 deletions src/vision/extensions.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
"""
AlexNet(; kwargs...)
Create an AlexNet model [krizhevsky2012imagenet](@citep).
## Keyword Arguments
- `pretrained::Bool=false`: If `true`, loads pretrained weights when `LuxCore.setup` is
called.
"""
function AlexNet end

"""
ResNet(depth::Int; kwargs...)
Expand Down Expand Up @@ -111,7 +99,7 @@ function ConvMixer end
pretrained::Bool
end

for f in [:AlexNet, :ResNet, :ResNeXt, :GoogLeNet, :DenseNet, :MobileNet, :ConvMixer]
for f in [:ResNet, :ResNeXt, :GoogLeNet, :DenseNet, :MobileNet, :ConvMixer]
f_metalhead = Symbol(f, :Metalhead)
@eval begin
function $(f_metalhead) end
Expand Down

0 comments on commit 0d9f03f

Please sign in to comment.