Commit ec70b9d

Merge pull request #10 from LuxDL/ap/lux0.5
Updates for Lux 0.5 support
avik-pal authored Aug 20, 2023
2 parents 2415f3b + 7f981de commit ec70b9d
Showing 17 changed files with 221 additions and 256 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
@@ -4,6 +4,5 @@ always_use_return = true
  margin = 92
  indent = 4
  format_docstrings = true
- join_lines_based_on_source = false
  separate_kwargs_with_semicolon = true
  always_for_in = true
38 changes: 37 additions & 1 deletion .buildkite/pipeline.yml
@@ -16,7 +16,43 @@ steps:
      env:
        GROUP: "CUDA"
      if: build.message !~ /\[skip tests\]/
-     timeout_in_minutes: 60
+     timeout_in_minutes: 240
+     matrix:
+       setup:
+         julia:
+           - "1"
+           - "1.6"
+           - "nightly"
+       adjustments:
+         - with:
+             julia: "1.6"
+           soft_fail: true
+         - with:
+             julia: "nightly"
+           soft_fail: true
+
+   - label: ":julia: Julia: {{matrix.julia}} + AMD GPU"
+     plugins:
+       - JuliaCI/julia#v1:
+           version: "{{matrix.julia}}"
+       - JuliaCI/julia-test#v1:
+           test_args: "--quickfail"
+       - JuliaCI/julia-coverage#v1:
+           codecov: true
+           dirs:
+             - src
+             - ext
+     env:
+       JULIA_AMDGPU_CORE_MUST_LOAD: "1"
+       JULIA_AMDGPU_HIP_MUST_LOAD: "1"
+       JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
+       GROUP: "AMDGPU"
+     agents:
+       queue: "juliagpu"
+       rocm: "*"
+       rocmgpu: "*"
+     if: build.message !~ /\[skip tests\]/
+     timeout_in_minutes: 240
+     matrix:
+       setup:
+         julia:
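
The `GROUP` environment variable set above (`"CUDA"` here, `"AMDGPU"` in the new step) is the conventional switch a Julia test suite reads to pick a backend. Below is a minimal sketch of how a `test/runtests.jl` might consume it; the structure is an assumption, since Boltz's actual test runner is not part of this diff.

```julia
# Hypothetical sketch of a test runner gated on the GROUP variable set in
# the Buildkite pipeline above; Boltz's real runtests.jl may differ.
using Test

const GROUP = get(ENV, "GROUP", "All")

@testset "Boltz" begin
    if GROUP == "All" || GROUP == "CUDA"
        @testset "CUDA backend" begin
            # CUDA-specific tests would go here
        end
    end
    if GROUP == "All" || GROUP == "AMDGPU"
        @testset "AMDGPU backend" begin
            # AMDGPU-specific tests would go here
        end
    end
end
```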
60 changes: 30 additions & 30 deletions Artifacts.toml
@@ -6,45 +6,45 @@ lazy = true
sha256 = "e20107404aba1c2c0ed3fad4314033a2fa600cdc0c55d03bc1bfe4f8e5031105"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/alexnet.tar.gz"

[resnet101]
git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896"
lazy = true
# [resnet101]
# git-tree-sha1 = "6c9143d40950726405b88db0cc021fa1dcbc0896"
# lazy = true

[[resnet101.download]]
sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz"
# [[resnet101.download]]
# sha256 = "3840f05b3d996b2b3ea1e8fb6617775fd60ad6b8769402200fdc9c8b8dca246f"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet101.tar.gz"

[resnet152]
git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297"
lazy = true
# [resnet152]
# git-tree-sha1 = "892915c44de37537aad97da3de8a4458dfa36297"
# lazy = true

[[resnet152.download]]
sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz"
# [[resnet152.download]]
# sha256 = "6033a1ecc46d7f4ed1139067c5f9f5ea0d247656e9abbbe755c4702ec5a636d6"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet152.tar.gz"

[resnet18]
git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de"
lazy = true
# [resnet18]
# git-tree-sha1 = "1d4a46fee1bb87eeef0ce2c85f63cfe0ff47d4de"
# lazy = true

[[resnet18.download]]
sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz"
# [[resnet18.download]]
# sha256 = "f4041ea1d1ec9bba86c7a5a519daaa49bb096a55fcd4ebf74f0743c8bdcb1c35"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet18.tar.gz"

[resnet34]
git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9"
lazy = true
# [resnet34]
# git-tree-sha1 = "306a8055ae9207ae2a316e31b376254557e481c9"
# lazy = true

[[resnet34.download]]
sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz"
# [[resnet34.download]]
# sha256 = "d62e40ee9213ea9611e3fcedc958df4011da1fa108fb1537bac91e6b7778a3c8"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet34.tar.gz"

[resnet50]
git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133"
lazy = true
# [resnet50]
# git-tree-sha1 = "8c5866edb29b53f581a9ed7148efa1dbccde6133"
# lazy = true

[[resnet50.download]]
sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a"
url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz"
# [[resnet50.download]]
# sha256 = "275365d76e592c6ea35574853a75ee068767641664e7817aedf394fcd7fea25a"
# url = "https://github.com/avik-pal/Lux.jl/releases/download/weights/resnet50.tar.gz"

[vgg11]
git-tree-sha1 = "ea7e8ef9399a0fe0aad2331781af5d6435950d36"
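
The remaining entries are consumed through LazyArtifacts, so each weight tarball is downloaded only on first use. Below is a hedged sketch of how an entry such as `alexnet` resolves at runtime, assuming it runs from code that carries this `Artifacts.toml`.

```julia
# Sketch: materializing a lazy artifact declared in Artifacts.toml. The
# artifact string macro downloads and unpacks the tarball on first access,
# then returns the local directory path.
using LazyArtifacts

weights_dir = artifact"alexnet"  # triggers the download once
@show readdir(weights_dir)       # expected to contain the serialized weights
```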
20 changes: 12 additions & 8 deletions Project.toml
@@ -1,35 +1,39 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.2.1"
version = "0.3.0"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"

[extensions]
BoltzFluxMetalheadExt = "Metalhead"
BoltzLuxAMDGPUExt = "LuxAMDGPU"
BoltzLuxCUDAExt = "LuxCUDA"
BoltzMetalheadExt = "Metalhead"

[compat]
CUDA = "3, 4"
ChainRulesCore = "1.15"
JLD2 = "0.4"
Lux = "0.4.26"
Metalhead = "0.7"
NNlib = "0.8, 0.9"
Lux = "0.5"
LuxAMDGPU = "0.1"
LuxCUDA = "0.2, 0.3"
Metalhead = "0.7, 0.8"
PackageExtensionCompat = "1"
julia = "1.6"

[extras]
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
10 changes: 8 additions & 2 deletions README.md
@@ -23,16 +23,22 @@ Pkg.add("Boltz")
  ## Getting Started

  ```julia
- using Boltz, Lux
+ using Boltz, Lux, Metalhead

  model, ps, st = resnet(:resnet18; pretrained=true)
  ```

  ## Changelog

+ ### Updating from v0.2 to v0.3
+
+ CUDA is not loaded by default. To use GPUs follow
+ [Lux.jl documentation](https://lux.csail.mit.edu/stable/manual/gpu_management/).
+
  ### Updating from v0.1 to v0.2

  We have moved some dependencies into weak dependencies. This means that you will have to
  manually load them for certain features to be available.

- * To load Flux & Metalhead models, do `using Flux, Metalhead`.
+ * To load Flux & Metalhead models, do `using Metalhead`.
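
A possible continuation of the Getting Started snippet, not taken from the README itself: running a forward pass with the explicit parameters and state Lux returns, and moving to a GPU under the new opt-in CUDA policy. The WHCN input layout, `Lux.testmode`, and the `gpu_device` helper are assumptions based on common Lux 0.5-era conventions; check the linked GPU-management docs for the API matching your Lux version.

```julia
# Hedged inference sketch. Lux models take (input, parameters, state) and
# return (output, new_state); testmode freezes stateful layers such as
# batch normalization for evaluation.
using Boltz, Lux, Metalhead

model, ps, st = resnet(:resnet18; pretrained=true)

x = randn(Float32, 224, 224, 3, 1)     # one 224x224 RGB image, WHCN layout
y, _ = model(x, ps, Lux.testmode(st))  # class scores

# GPU use is now opt-in: load LuxCUDA explicitly, then move the parameter
# and state trees over. The device helper below is an assumption; see the
# Lux.jl GPU management manual linked above.
using LuxCUDA

if LuxCUDA.functional()
    dev = gpu_device()
    ps_gpu, st_gpu = dev(ps), dev(st)
end
```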
10 changes: 5 additions & 5 deletions docs/src/api/vision.md
@@ -58,11 +58,11 @@ Boltz._vgg_convolutional_layers
  | MODEL NAME | FUNCTION    | NAME         | PRETRAINED | TOP 1 ACCURACY (%) | TOP 5 ACCURACY (%) |
  | ---------- | ----------- | ------------ | :--------: | :----------------: | :----------------: |
  | AlexNet    | `alexnet`   | `:alexnet`   |     ✅     |       54.48        |       77.72        |
- | ResNet     | `resnet`    | `:resnet18`  |     ✅     |       68.08        |       88.44        |
- | ResNet     | `resnet`    | `:resnet34`  |     ✅     |       72.13        |       90.91        |
- | ResNet     | `resnet`    | `:resnet50`  |     ✅     |       74.55        |       92.36        |
- | ResNet     | `resnet`    | `:resnet101` |     ✅     |       74.81        |       92.36        |
- | ResNet     | `resnet`    | `:resnet152` |     ✅     |       77.63        |       93.84        |
+ | ResNet     | `resnet`    | `:resnet18`  |     🚫     |       68.08        |       88.44        |
+ | ResNet     | `resnet`    | `:resnet34`  |     🚫     |       72.13        |       90.91        |
+ | ResNet     | `resnet`    | `:resnet50`  |     🚫     |       74.55        |       92.36        |
+ | ResNet     | `resnet`    | `:resnet101` |     🚫     |       74.81        |       92.36        |
+ | ResNet     | `resnet`    | `:resnet152` |     🚫     |       77.63        |       93.84        |
  | ConvMixer  | `convmixer` | `:small`     |     🚫     |                    |                    |
  | ConvMixer  | `convmixer` | `:base`      |     🚫     |                    |                    |
  | ConvMixer  | `convmixer` | `:large`     |     🚫     |                    |                    |
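
For reference, the FUNCTION and NAME columns combine into calls of the following shape. The exact keyword surface is assumed from the README example, and after this commit the Metalhead-backed constructors need `using Metalhead`; only `:alexnet` in this table still ships pretrained weights.

```julia
# Sketch of the calling convention the table encodes: FUNCTION is called
# with NAME as a Symbol. Keywords beyond `pretrained` are assumptions.
using Boltz, Lux, Metalhead

model, ps, st = alexnet(:alexnet; pretrained=true)  # pretrained available
model, ps, st = resnet(:resnet50)                   # random initialization
model, ps, st = convmixer(:small)                   # random initialization
```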
12 changes: 12 additions & 0 deletions ext/BoltzLuxAMDGPUExt.jl
@@ -0,0 +1,12 @@
+ module BoltzLuxAMDGPUExt
+
+ using Boltz, LuxAMDGPU
+
+ # NOTE(@avik-pal): Most ROCArray dispatches rely on a contiguous memory layout. Copying
+ #                  might be slow but allows us to use the faster and more reliable
+ #                  dispatches.
+ @inline function Boltz._fast_chunk(x::ROCArray, h::Int, n::Int, ::Val{dim}) where {dim}
+     return copy(selectdim(x, dim, Boltz._fast_chunk(h, n)))
+ end
+
+ end
12 changes: 12 additions & 0 deletions ext/BoltzLuxCUDAExt.jl
@@ -0,0 +1,12 @@
+ module BoltzLuxCUDAExt
+
+ using Boltz, LuxCUDA
+
+ # NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
+ #                  might be slow but allows us to use the faster and more reliable
+ #                  dispatches.
+ @inline function Boltz._fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
+     return copy(selectdim(x, dim, Boltz._fast_chunk(h, n)))
+ end
+
+ end
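
The contiguity trade-off both NOTE comments describe can be illustrated on plain CPU arrays: `selectdim` returns a lazy `SubArray`, while `copy` materializes a contiguous `Array`, the layout most GPU kernels dispatch on.

```julia
# CPU-only illustration of why the extensions above copy the selected
# slice: selectdim yields a view; copy produces a dense, contiguous array
# of the same values, trading one allocation for friendlier GPU dispatches.
x = reshape(collect(1.0:24.0), 4, 6)

v = selectdim(x, 2, 1:3)  # SubArray (a view), no data movement
c = copy(v)               # contiguous Matrix{Float64}

@assert v == c
@assert v isa SubArray
@assert c isa Matrix{Float64}
```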
7 changes: 1 addition & 6 deletions ext/BoltzFluxMetalheadExt.jl → ext/BoltzMetalheadExt.jl
@@ -1,4 +1,4 @@
- module BoltzFluxMetalheadExt
+ module BoltzMetalheadExt

  using Boltz, Lux, Metalhead
  import Boltz: alexnet, convmixer, densenet, googlenet, mobilenet, resnet, resnext

@@ -30,11 +30,6 @@ function resnet(name::Symbol; pretrained=false, kwargs...)
          transform(ResNet(152).layers)
      end

-     # Compatibility with pretrained weights
-     if pretrained
-         model = Chain(model[1], model[2])
-     end
-
      return _initialize_model(name, model; pretrained, kwargs...)
  end

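
The extension builds its Lux models by converting Metalhead's Flux layers with `transform`, visible in the hunk above. A hedged end-to-end sketch of that conversion path under Lux 0.5 follows; the exact behavior of the converted model is not guaranteed here.

```julia
# Sketch of the Flux-to-Lux conversion this extension relies on. The
# transform call mirrors the usage shown in the diff; Lux.setup then
# produces the explicit parameter and state trees.
using Lux, Metalhead, Random

flux_layers = ResNet(18).layers     # a Flux chain from Metalhead
lux_model = transform(flux_layers)  # wrap it as a Lux layer
ps, st = Lux.setup(Random.default_rng(), lux_model)
```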
11 changes: 2 additions & 9 deletions src/Boltz.jl
@@ -1,6 +1,6 @@
  module Boltz

- using CUDA, Lux, NNlib, Random, Statistics
+ using Lux, Random, Statistics
  # Loading Pretained Weights
  using Artifacts, JLD2, LazyArtifacts
  # AD Support

@@ -13,14 +13,7 @@ function __init__()
  end

  # Define functions. Methods defined in files or in extensions later
- for f in (:alexnet,
-     :convmixer,
-     :densenet,
-     :googlenet,
-     :mobilenet,
-     :resnet,
-     :resnext,
-     :vgg,
+ for f in (:alexnet, :convmixer, :densenet, :googlenet, :mobilenet, :resnet, :resnext, :vgg,
      :vision_transformer)
      @eval function $(f) end
  end
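
The loop above declares zero-method generic functions so that methods can be attached later from other source files or from package extensions. A standalone illustration of the pattern:

```julia
# Standalone illustration of the function-stub pattern used in Boltz.jl:
# declare empty generic functions now, attach methods later (for example
# from a package extension).
module Stubs

for f in (:alexnet, :resnet)
    @eval function $(f) end  # generic function with zero methods
end

end

# Later, e.g. inside an extension module:
function Stubs.resnet(name::Symbol; pretrained::Bool=false)
    return "building $(name) (pretrained = $(pretrained))"
end

@assert isempty(methods(Stubs.alexnet))
@assert Stubs.resnet(:resnet18) == "building resnet18 (pretrained = false)"
```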
14 changes: 2 additions & 12 deletions src/utils.jl
@@ -7,12 +7,6 @@ Type-stable and faster version of `MLUtils.chunk`
  @inline function _fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
      return selectdim(x, dim, _fast_chunk(h, n))
  end
- # NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
- #                  might be slow but allows us to use the faster and more reliable
- #                  dispatches.
- @inline function _fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
-     return copy(selectdim(x, dim, _fast_chunk(h, n)))
- end
  @inline function _fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
      return _fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
  end
@@ -49,12 +49,8 @@ function _get_pretrained_weights_path(name::String)
      end
  end

- function _initialize_model(name::Symbol,
-     model;
-     pretrained::Bool=false,
-     rng=nothing,
-     seed=0,
-     kwargs...)
+ function _initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing,
+     seed=0, kwargs...)
      if pretrained
          path = _get_pretrained_weights_path(name)
          ps = load(joinpath(path, "$name.jld2"), "parameters")
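
For context, `_fast_chunk` splits an array into `N` equal slices along dimension `D` using `selectdim` views, with GPU arrays additionally copied by the new extensions. Below is a self-contained re-implementation of the semantics; the integer helper's window arithmetic is an assumption, since that method is not shown in this diff.

```julia
# Hedged re-implementation of the chunking semantics, for illustration only.
# fast_chunk(h, n) is assumed to return the n-th index window of width h;
# Boltz's private helper may differ.
fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)

function fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
    return selectdim(x, dim, fast_chunk(h, n))
end

function fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
    return fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
end

x = reshape(collect(1:12), 3, 4)
chunks = fast_chunk(x, Val(2), Val(2))  # two 3x2 views along dimension 2
@assert length(chunks) == 2
@assert size(chunks[1]) == (3, 2)
```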

2 comments on commit ec70b9d

@avik-pal (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/89967

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" ec70b9de0aba4998aef1fb9d0ee652f64123a8c2
git push origin v0.3.0
