revert: Hypernet keep in CUDA for now
avik-pal committed Jan 8, 2025
1 parent e2e6470 commit cab3e1f
Showing 2 changed files with 51 additions and 83 deletions.
12 changes: 5 additions & 7 deletions examples/HyperNet/Project.toml
@@ -1,23 +1,21 @@
 [deps]
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
-Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-ComponentArrays = "0.15.21"
+ComponentArrays = "0.15.18"
 Lux = "1"
+LuxCUDA = "0.3"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
-Printf = "1.10"
-Random = "1.10"
-Reactant = "0.2.14"
-Setfield = "1"
+Zygote = "0.6"
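
With the revert, the example environment depends on LuxCUDA and Zygote again rather than Reactant. A sketch of instantiating it locally (the path assumes a checkout of the repository):

    using Pkg
    Pkg.activate("examples/HyperNet")  # the Project.toml above
    Pkg.instantiate()                  # pulls LuxCUDA, Zygote, and friends
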
122 changes: 46 additions & 76 deletions examples/HyperNet/main.jl
@@ -2,43 +2,31 @@

# ## Package Imports

-using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
-    Reactant
+using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
+    Printf, Random, Zygote

+CUDA.allowscalar(false)

# ## Loading Datasets
-function load_dataset(
-    ::Type{dset}, n_train::Union{Nothing, Int},
-    n_eval::Union{Nothing, Int}, batchsize::Int
-) where {dset}
-    data = dset(:train)
-    (imgs, labels) = if n_train === nothing
-        n_train = size(data.features, ndims(data.features))
-        data.features, data.targets
+function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
+        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
+    if n_train === nothing
+        imgs, labels = dset(:train)
    else
-        data = data[1:n_train]
-        data.features, data.targets
+        imgs, labels = dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

-    data = dset(:test)
-    (imgs, labels) = if n_eval === nothing
-        n_eval = size(data.features, ndims(data.features))
-        data.features, data.targets
+    if n_eval === nothing
+        imgs, labels = dset(:test)
    else
-        data = data[1:n_eval]
-        data.features, data.targets
+        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

    return (
-        DataLoader(
-            (x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true,
-            partial=false
-        ),
-        DataLoader(
-            (x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false,
-            partial=false
-        )
+        DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
+        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false)
    )
end
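
For reference, the reverted load_dataset returns a (train, test) pair of DataLoaders. A usage sketch with illustrative sizes (MNIST stands in for either dataset type):

    using MLDatasets, MLUtils, OneHotArrays

    # 1024 training and 256 test images, shuffled training batches of 128
    train_loader, test_loader = load_dataset(MNIST, 1024, 256, 128)

    x, y = first(train_loader)
    size(x)  # (28, 28, 1, 128) -- WHCN image batch
    size(y)  # (10, 128)        -- one-hot labels over the digits 0:9
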

@@ -49,19 +37,21 @@ function load_datasets(batchsize=256)
end

# ## Implement a HyperNet Layer
-function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
+function HyperNet(
+        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
-              ComponentArray |> getaxes
+              ComponentArray |>
+              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        ## Generate the weights
-        ps_new = ComponentArray(Lux.Utils.vec(weight_generator(x)), ca_axes)
+        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
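
The ca_axes trick is what lets the flat vector produced by weight_generator be reinterpreted as the structured parameters of core_network. The same round trip in isolation (a minimal sketch with a small Dense layer):

    using ComponentArrays, Lux, Random

    layer = Dense(4 => 2)
    ps = ComponentArray(Lux.initialparameters(Random.default_rng(), layer))
    ax = getaxes(ps)          # remembers the (weight, bias) structure

    flat = getdata(ps)        # plain Vector{Float32} of length 10
    ps2 = ComponentArray(flat, ax)
    ps2.weight == ps.weight   # true -- structure recovered from the flat vector
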

# Defining functions on the CompactLuxLayer requires some understanding of how the layer
# is structured, as such we don't recommend doing it unless you are familiar with the
-# internals. In this case, we simply write it to ignore the initialization of the
+# internals. In this case, we simply write it to ignore the initialization of the
# `core_network` parameters.

function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
@@ -73,73 +63,61 @@ function create_model()
    ## Doesn't need to be a MLP can have any Lux Layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
-        Embedding(1 => 32),
+        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )
-    return HyperNet(weight_generator, core_network)

+    model = HyperNet(weight_generator, core_network)
+    return model
end
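
Embedding(2 => 32) holds one learned 32-dimensional code per task (index 1 = MNIST, 2 = FashionMNIST); the generator expands that code into every weight of the core network. A forward-pass sketch, assuming the definitions above:

    model = create_model()
    ps, st = Lux.setup(Xoshiro(0), model)

    x = rand(Float32, 28, 28, 1, 4)   # four dummy images
    y1, _ = model((1, x), ps, st)     # logits with MNIST-specific weights
    y2, _ = model((2, x), ps, st)     # same input, FashionMNIST weights
    size(y1)                          # (10, 4), and y1 != y2
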

# ## Define Utility Functions
const loss = CrossEntropyLoss(; logits=Val(true))

-function accuracy(model, ps, st, dataloader, idx)
+function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
-        predicted_class = onecold(Array(first(model((idx, x), ps, st))))
+        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
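
onecold is the inverse of the onehotbatch call in the loaders: it maps each column back to the index (or label) of its largest entry, so comparing predictions with targets columnwise yields the accuracy. A quick sketch:

    using OneHotArrays

    y = onehotbatch([3, 1, 4], 0:9)   # 10x3 one-hot matrix
    onecold(y, 0:9)                   # [3, 1, 4]
    onecold(y)                        # [4, 2, 5] -- 1-based indices, as used above
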

# ## Training
-function train(; dev=reactant_device())
+function train()
    model = create_model()
-    dataloaders = load_datasets(256) |> dev

-    ps, st = Lux.setup(Random.default_rng(), model) |> dev
+    dataloaders = load_datasets()

-    if dev isa ReactantDevice
-        idx = ConcreteRNumber(1)
-        x = dev(rand(Float32, 28, 28, 1, 256))
-        model_compiled = @compile model((idx, x), ps, Lux.testmode(st))
-    else
-        model_compiled = model
-    end
+    dev = gpu_device()
+    rng = Xoshiro(0)
+    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Lets train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
-        train_dataloader, test_dataloader = dataloaders[data_idx]
-        idx = dev isa ReactantDevice ? ConcreteRNumber(data_idx) : data_idx
+        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
-                AutoEnzyme(), loss, ((idx, x), y), train_state
-            )
+                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
-            accuracy(
-                model_compiled, train_state.parameters,
-                train_state.states, train_dataloader, idx
-            ) * 100;
-            digits=2
-        )
+            accuracy(model, train_state.parameters,
+                train_state.states, train_dataloader, data_idx) * 100;
+            digits=2)
        test_acc = round(
-            accuracy(
-                model_compiled, train_state.parameters,
-                train_state.states, test_dataloader, idx
-            ) * 100;
-            digits=2
-        )
+            accuracy(model, train_state.parameters,
+                train_state.states, test_dataloader, data_idx) * 100;
+            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

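The core of the revert is the backend swap in Training.single_train_step!: AutoZygote() replaces AutoEnzyme(), so gradients come from Zygote again. The call returns the gradients, the loss, auxiliary stats, and the updated TrainState; spelled out (a sketch using the loop's names):

    x, y = first(train_dataloader)
    grads, loss_val, stats, train_state = Training.single_train_step!(
        AutoZygote(), loss, ((data_idx, x), y), train_state)
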
@@ -151,23 +129,15 @@ function train(; dev=reactant_device())

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
-        train_dataloader, test_dataloader = dataloaders[data_idx]
-        idx = dev isa ReactantDevice ? ConcreteRNumber(data_idx) : data_idx
+        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
-            accuracy(
-                model_compiled, train_state.parameters,
-                train_state.states, train_dataloader, idx
-            ) * 100;
-            digits=2
-        )
+            accuracy(model, train_state.parameters,
+                train_state.states, train_dataloader, data_idx) * 100;
+            digits=2)
        test_acc = round(
-            accuracy(
-                model_compiled, train_state.parameters,
-                train_state.states, test_dataloader, idx
-            ) * 100;
-            digits=2
-        )
+            accuracy(model, train_state.parameters,
+                train_state.states, test_dataloader, data_idx) * 100;
+            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

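A note on `dataloaders[data_idx] .|> dev`: broadcasting `|>` over the (train, test) pair sends each DataLoader through gpu_device(), so every batch it yields is copied to the GPU as it is consumed. A sketch, assuming LuxCUDA and a CUDA-capable device:

    dev = gpu_device()                     # CUDADevice() when a GPU is available
    train_loader, test_loader = load_datasets()[1] .|> dev
    x, y = first(train_loader)             # x and y now live on the GPU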