From f7806c83440f1622e5f7010276244fc98c331022 Mon Sep 17 00:00:00 2001
From: AntonOresten
Date: Mon, 2 Dec 2024 19:10:38 +0100
Subject: [PATCH] Use type parameters

---
 src/lora.jl | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/lora.jl b/src/lora.jl
index 4ef5d08..742d277 100644
--- a/src/lora.jl
+++ b/src/lora.jl
@@ -1,7 +1,7 @@
-struct LoRADense
-    primary::Dense
-    proj1::Dense
-    proj2::Dense
+struct LoRADense{P0<:Dense,P1<:Dense,P2<:Dense}
+    primary::P0
+    proj1::P1
+    proj2::P2
 end
 
 """
@@ -10,8 +10,7 @@ end
 Create a LoRA wrapper around a Dense layer. The second projection matrix is initialized to zero, and only the two projections (and not the primary layer) are trainable.
 """
 function LoRADense(primary::Dense, hidden_dim::Int; init=Flux.kaiming_uniform())
-    dim1 = size(primary.weight, 2)
-    dim2 = size(primary.weight, 1)
+    dim2, dim1 = size(primary.weight)
     ld = LoRADense(
         primary,
         Dense(dim1 => hidden_dim, bias=false, init = init),
@@ -21,8 +20,6 @@ function LoRADense(primary::Dense, hidden_dim::Int; init=Flux.kaiming_uniform())
     return ld
 end
 
-function (lora::LoRADense)(x)
-    return lora.primary(x) .+ lora.proj2(lora.proj1(x))
-end
+(lora::LoRADense)(x) = lora.primary(x) .+ lora.proj2(lora.proj1(x))
 
 Flux.@layer :expand LoRADense trainable=(proj1, proj2)
\ No newline at end of file
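
A minimal usage sketch (not part of the patch), assuming Flux.jl is installed and the script is run from the repository root so the patched src/lora.jl can be included; the layer sizes and LoRA rank below are illustrative only:

using Flux

include("src/lora.jl")    # the file modified by this patch

# Wrap a 128 => 64 Dense layer with a rank-8 adapter (sizes are made up).
base = Dense(128 => 64)
lora = LoRADense(base, 8)

x = randn(Float32, 128, 32)    # a batch of 32 column vectors
y = lora(x)                    # computes primary(x) .+ proj2(proj1(x))

# proj2 is initialized to zero, so the wrapper initially reproduces the primary
# layer exactly; only proj1 and proj2 are trainable, per Flux.@layer above.
@assert y ≈ base(x)

# With the type parameters introduced by this patch, the concrete Dense types of
# the three fields appear in typeof(lora), so field accesses in the forward pass
# are type-stable instead of going through abstractly typed fields.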