
Commit

Use type parameters
AntonOresten committed Dec 2, 2024
1 parent fc20672 commit f7806c8
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/lora.jl
@@ -1,7 +1,7 @@
-struct LoRADense
-    primary::Dense
-    proj1::Dense
-    proj2::Dense
+struct LoRADense{P0<:Dense,P1<:Dense,P2<:Dense}
+    primary::P0
+    proj1::P1
+    proj2::P2
end

"""
@@ -10,8 +10,7 @@ end
Create a LoRA wrapper around a Dense layer. The second projection matrix is initialized to zero, and only the two projections (and not the primary layer) are trainable.
"""
function LoRADense(primary::Dense, hidden_dim::Int; init=Flux.kaiming_uniform())
-    dim1 = size(primary.weight, 2)
-    dim2 = size(primary.weight, 1)
+    dim2, dim1 = size(primary.weight)
    ld = LoRADense(
        primary,
        Dense(dim1 => hidden_dim, bias=false, init = init),
@@ -21,8 +20,6 @@ function LoRADense(primary::Dense, hidden_dim::Int; init=Flux.kaiming_uniform())
    return ld
end

-function (lora::LoRADense)(x)
-    return lora.primary(x) .+ lora.proj2(lora.proj1(x))
-end
+(lora::LoRADense)(x) = lora.primary(x) .+ lora.proj2(lora.proj1(x))

Flux.@layer :expand LoRADense trainable=(proj1, proj2)
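
Why the type parameters help: Flux's Dense is itself a parametric struct, so a field annotated primary::Dense is abstractly typed and code reading it cannot be fully specialized; with P0<:Dense each constructed LoRADense carries the concrete Dense types of its fields. A minimal sketch of the difference, using hypothetical struct names that are not part of this repository:

using Flux

# Hypothetical "before" layout: the field is annotated with the un-parameterized
# Dense, so its declared type is abstract and access to it is not type-stable.
struct AbstractlyTypedLoRA
    primary::Dense
end

# Hypothetical "after" layout mirroring this commit: the concrete Dense type is
# captured in the parameter, so the field is concretely typed.
struct ParametricLoRA{P<:Dense}
    primary::P
end

layer = Dense(2 => 3)
isconcretetype(fieldtype(AbstractlyTypedLoRA, :primary))            # false
isconcretetype(fieldtype(ParametricLoRA{typeof(layer)}, :primary))  # true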
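
And a usage sketch, assuming the constructor and call overload above and a recent Flux version; the layer sizes and batch size are arbitrary illustration values:

using Flux

base = Dense(8 => 4)
lora = LoRADense(base, 2)    # rank-2 LoRA wrapper around an 8 => 4 Dense layer

x = randn(Float32, 8, 16)
lora(x) ≈ base(x)            # true at initialization, since proj2 starts at zero

Flux.trainable(lora)         # (proj1 = ..., proj2 = ...): only the projections are trained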
