merge upstream #198

Closed
wants to merge 13 commits
1,280 changes: 1,280 additions & 0 deletions Manifest.toml

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions Project.toml
@@ -26,7 +26,6 @@ LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -42,6 +41,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"
WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
BytePairEncoding = "0.4"
@@ -52,16 +52,15 @@ DataStructures = "0.18"
DoubleArrayTries = "0.1"
Fetch = "0.1.3"
FillArrays = "0.13, 1"
Flux = "0.13.4"
Flux = "0.13, 0.14"
FuncPipelines = "0.2.3"
Functors = "0.2, 0.3, 0.4"
HTTP = "0.9, 1"
HuggingFaceApi = "0.1"
JSON3 = "1.12"
LRUCache = "1.5"
LightXML = "0.9"
NNlib = "0.8"
NNlibCUDA = "0.2"
NNlib = "0.8, 0.9"
NeuralAttentionlib = "0.2.12"
Pickle = "0.3"
PrimitiveOneHot = "0.1"
2 changes: 2 additions & 0 deletions README.md
@@ -1,3 +1,5 @@
# THIS FORK IS CUSTOMIZED FOR DECODER_ONLY MODELS.

<div align="center"> <img src="images/transformerslogo.png" alt="Transformers.jl" width="512"></img></div>

[![Build status](https://github.com/chengchingwen/Transformers.jl/workflows/CI/badge.svg)](https://github.com/chengchingwen/Transformers.jl/actions)
2 changes: 1 addition & 1 deletion src/cuda.jl
@@ -4,7 +4,7 @@ using NeuralAttentionlib

function _togpudevice(x, cache)
# https://github.com/FluxML/Flux.jl/blob/79971741ed8454cdf6a66515799a0c4b864f564a/src/functor.jl#L206-L209
Flux.check_use_cuda()
# Flux.check_use_cuda()
return Flux.fmap(
x -> Flux.adapt(Flux.FluxCUDAAdaptor(), x),
x; exclude = Flux._isleaf, cache)
4 changes: 3 additions & 1 deletion src/layers/Layers.jl
@@ -7,14 +7,16 @@ export Seq2Seq, Transformer,
TransformerBlock, TransformerDecoderBlock,
PreNormTransformerBlock, PostNormTransformerBlock,
PreNormTransformerDecoderBlock, PostNormTransformerDecoderBlock,
Embed, EmbedDecoder, FixedLenPositionEmbed, SinCosPositionEmbed
Embed, EmbedDecoder, FixedLenPositionEmbed, SinCosPositionEmbed,
RotaryPositionEmbed

include("./utils.jl")
include("./architecture.jl")
include("./base.jl")
include("./embed.jl")
include("./layer.jl")
include("./attention_op.jl")
include("./causal_flash_op.jl")
include("./structwalk.jl")
include("./testmode.jl")

201 changes: 201 additions & 0 deletions src/layers/causal_flash_op.jl
@@ -0,0 +1,201 @@
using CUDA
@inline function compute_shmem_size(d, Bs)
return (Bs * d * 3 + 4 * d + Bs * Bs) * sizeof(Float32)
end

"""
setMaxShmem(shmem)

Set the maximum dynamic shared memory size of the flash-attention kernel to `shmem` KB.
"""
function setMaxShmem(shmem)
kernel = cufunction(flash_attention_kernel, NTuple{4, CuDeviceArray{Float16, 4, 1}})
return CUDA.cuFuncSetAttribute(kernel.fun,
CUDA.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shmem * 1024)
end
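For orientation, a quick worked example of the shared-memory budget these two helpers manage. The head dimension d = 64 and block size Bs = 64 are illustrative assumptions, not values fixed by this file:

compute_shmem_size(64, 64)    # (64*64*3 + 4*64 + 64*64) * 4 = 66560 bytes, roughly 65 KB,
                              # more than the 48 KB a kernel may use by default, so a call
                              # such as setMaxShmem(65) would be needed to opt in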

function _checkbounds(Q, K, V)
sQ, sK, sV = size(Q), size(K), size(V)
sK != sV && throw(DimensionMismatch("K and V must have the same shape"))
sQ[3:4] != sK[3:4] &&
throw(DimensionMismatch("Q, K and V must have the same batch size and head size"))
sQ[1] != sK[1] &&
throw(DimensionMismatch("Q, K and V must have the same hidden dimension"))
return nothing
end

@inline function mod1_pow2(x, y)
r = x & (y - 1)
return ifelse(r == 0, y, r)
end


function causal_flash_attention_kernel(Q, K, V, O)
d = size(K, 1)
power = trailing_zeros(d)
tx = threadIdx().x
Bs = blockDim().x # assume Br == Bc
col = (blockIdx().x - 1) * Bs + tx
# skip computation if col < row

# allocate shared memory
T = eltype(Q)
shmem_offset = 0
q = CuDynamicSharedArray(T, (Bs + 2, d), shmem_offset) # pad 2 rows to avoid bank conflicts
shmem_offset += sizeof(q)
o = CuDynamicSharedArray(T, (Bs + 2, d), shmem_offset) # pad 2 rows to avoid bank conflicts
shmem_offset += sizeof(o)
k = CuDynamicSharedArray(T, (d, Bs), shmem_offset) # no padding; this buffer is reused for the V tile later
shmem_offset += sizeof(k)
s = CuDynamicSharedArray(T, (Bs, Bs), shmem_offset)

# load Q to shared memory, note that this is done only once
Q_offset = d * Bs * (blockIdx().x - 1) +
stride(Q, 3) * (blockIdx().y - 1) +
stride(Q, 4) * (blockIdx().z - 1)
K_offset = stride(K, 3) * (blockIdx().y - 1) + stride(K, 4) * (blockIdx().z - 1)

for i in 0:(d - 1)
idx = i * Bs + tx
row = mod1_pow2(idx, d)
col = (idx - row) >> power + 1
@inbounds q[col, row] = Q[idx + Q_offset]
@inbounds o[idx] = zero(T)
@inbounds k[idx] = K[idx + K_offset]
end

sync_threads()

# initialize lseᵢ and mᵢ
lseᵢ = -typemax(T)
mᵢ = -typemax(T)

# the inner loop is serial
for _ in 1:cld(size(K, 2), Bs) # iterate over Bs elements in sequence
# initialize mᵢⱼ
mᵢⱼ = lseᵢ

# compute s=Q^TK
# s = (Bs, Bs)
#inf_block = true
for n in 1:Bs
if Q_offset + tx < K_offset + n
s[tx, n] = -Inf
continue
end
#inf_block = false

tmp = zero(T)
for m in 1:d
@inbounds tmp = CUDA.fma(q[tx, m], k[m, n], tmp)
end
s[tx, n] = tmp
@inbounds mᵢⱼ = max(mᵢⱼ, s[tx, n])
end
#inf_block && return nothing

sync_threads()

# compute P̃ᵢⱼ and lᵢⱼ
lᵢⱼ = zero(T)
for n in 1:Bs
@inbounds tmp = exp(s[tx, n] - mᵢⱼ)
@inbounds s[tx, n] = tmp
lᵢⱼ += tmp
end

# Load V to shared memory, which shares the same memory with k
for i in 0:(d - 1)
idx = i * Bs + tx
row = mod1_pow2(idx, d)
col = (idx - row) >> power + 1
@inbounds k[row, col] = V[idx + K_offset]
end

sync_threads()

# update o
for m in 1:d
tmp = o[tx, m] * exp(mᵢ - mᵢⱼ)
for n in 1:Bs
@inbounds tmp = CUDA.fma(s[tx, n], k[m, n], tmp) # k[m, n] * s[n, tx]
end
@inbounds o[tx, m] = tmp
end

mᵢ = mᵢⱼ
lseᵢ = mᵢⱼ + log(exp(lseᵢ - mᵢⱼ) + lᵢⱼ)

K_offset += Bs * d

# update k
for i in 0:(d - 1)
idx = i * Bs + tx
@inbounds k[idx] = K[idx + K_offset]
end
sync_threads()
end

for m in 1:d
@inbounds o[tx, m] = o[tx, m] * exp(mᵢ - lseᵢ)
end
sync_threads()

# write to O
for i in 0:(d - 1)
idx = i * Bs + tx
row = mod1_pow2(idx, d)
col = (idx - row) >> power + 1
@inbounds O[idx + Q_offset] = o[col, row]
end

return nothing
end

function causal_flash_attention(Q::CuArray{T, 4}, K::CuArray{T, 4}, V::CuArray{T, 4}) where {T}
_checkbounds(Q, K, V)
O = similar(Q)
kernel = @cuda launch=false causal_flash_attention_kernel(Q, K, V, O)
d, N, H, B = size(Q)
get_shmem = Base.Fix1(compute_shmem_size, d)
config = launch_configuration(kernel.fun; shmem=get_shmem, max_threads=256)

Bs = min(N, config.threads)
threads = (Bs, 1, 1)
blocks = (cld(N, Bs), H, B)
shmem = get_shmem(Bs)

kernel(Q, K, V, O; threads=threads, blocks=blocks, shmem=shmem)
return O
end

function causal_flash_attention(n_heads::Int, Q, K, V)
@assert ndims(Q) == ndims(K) == ndims(V) == 3 "Q, K, and V should be of size (d*h, n, b)"
Q_fa, K_fa, V_fa = Transformers_to_Flash(n_heads, Q), Transformers_to_Flash(n_heads, K), Transformers_to_Flash(n_heads, V)
O_fa = causal_flash_attention(Q_fa, K_fa, V_fa)
Flash_to_Transformers(O_fa)
end
function Transformers_to_Flash(n_heads::Int, arr)
d = Int(size(arr, 1) / n_heads)
N, B = size(arr, 2), size(arr, 3)
arr_4d = reshape(arr, d, n_heads, N, B)
perm = (1, 3, 2, 4)
permutedims(arr_4d, perm)
end

function Flash_to_Transformers(arr)
arr = permutedims(arr, (1,3,2,4))
N, B = size(arr, 3), size(arr, 4)
(hidden_state=reshape(arr, :, N, B),)
end


struct CausalFlashMultiheadQKVAttenOp{F} <: AbstractAttenOp
head::Int
p::F
end
CausalFlashMultiheadQKVAttenOp(head) = CausalFlashMultiheadQKVAttenOp(head, nothing)
NeuralAttentionlib.get_attention_func(::CausalFlashMultiheadQKVAttenOp) = causal_flash_attention
NeuralAttentionlib.get_attention_func_args(op::CausalFlashMultiheadQKVAttenOp, q, k, v, mask = nothing) = (op.head, q, k, v)
argument_names(::CausalFlashMultiheadQKVAttenOp) = (:hidden_state, :attention_mask)
apply_on_namedtuple(op::CausalFlashMultiheadQKVAttenOp, nt::NamedTuple) = apply_attention_op(op, nt)
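To make the intended entry points concrete, here is a rough usage sketch of the three-argument method and the attention op defined above. The element type, the sizes, and the assumption that these names are reachable after `using Transformers.Layers` are illustrative guesses, not something exercised by this diff:

using CUDA
using Transformers.Layers   # assuming the new names are reachable from here in this fork

# hidden size 256 = 8 heads * head dimension 32, sequence length 64, batch size 2
q = CUDA.rand(Float16, 256, 64, 2)   # (d*h, n, b), the layout the @assert above expects
k = CUDA.rand(Float16, 256, 64, 2)
v = CUDA.rand(Float16, 256, 64, 2)

out = causal_flash_attention(8, q, k, v)   # NamedTuple; out.hidden_state has size (256, 64, 2)

# as an attention op inside a layer, only the head count needs to be supplied:
op = CausalFlashMultiheadQKVAttenOp(8)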
4 changes: 4 additions & 0 deletions src/layers/embed.jl
@@ -227,6 +227,10 @@ function Base.show(io::IO, embed::SinCosPositionEmbed)
end
@fluxlayershow SinCosPositionEmbed false

struct RotaryPositionEmbed <: AbstractEmbedding end
(embed::RotaryPositionEmbed)(x) = NeuralAttentionlib.with_rotary_position_embedding(x)
@fluxlayershow RotaryPositionEmbed false

"""
ApplyEmbed([apply = .+,] embed)

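For context on the new embedding layer above, a minimal sketch of applying it to a feature array. The sizes are arbitrary, and the unqualified name assumes RotaryPositionEmbed is exported from Transformers.Layers, as the Layers.jl hunk suggests:

using Transformers.Layers

rope = RotaryPositionEmbed()
x = randn(Float32, 64, 128, 2)   # (feature dimension, sequence length, batch)
y = rope(x)                      # same size; rotary position rotation applied along dim 1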
44 changes: 25 additions & 19 deletions src/layers/layer.jl
@@ -43,20 +43,24 @@ end

(b::TransformerBlock)(nt::NamedTuple) = apply_on_namedtuple(b.feedforward, apply_on_namedtuple(b.attention, nt))

struct TransformerDecoderBlock{A, C, F} <: AbstractTransformerBlock
struct TransformerDecoderBlock{A, F} <: AbstractTransformerBlock
attention::A
crossattention::C
feedforward::F
end
@functor TransformerDecoderBlock

argument_names(b::TransformerDecoderBlock) = Base.merge_names(
Base.merge_names(argument_names(b.crossattention), argument_names(b.attention)),
argument_names(b.attention),
argument_names(b.feedforward)
)

# In the upstream block, attention, cross-attention, and feedforward are applied to nt in turn,
# each returning a NamedTuple (see the commented-out definition below).
# In this decoder-only fork, the cross-attention step is skipped.
# (b::TransformerDecoderBlock)(nt::NamedTuple) =
# apply_on_namedtuple(b.feedforward, apply_on_namedtuple(b.crossattention, apply_on_namedtuple(b.attention, nt)))
(b::TransformerDecoderBlock)(nt::NamedTuple) =
apply_on_namedtuple(b.feedforward, apply_on_namedtuple(b.crossattention, apply_on_namedtuple(b.attention, nt)))
apply_on_namedtuple(b.feedforward, apply_on_namedtuple(b.attention, nt))

struct Residual{L} <: LayerStruct
layer::L
Expand Down Expand Up @@ -97,10 +101,10 @@ end

const PreNormTransformerBlock{A, LN1, F, LN2} = TransformerBlock{ PreNormResidual{A, LN1}, PreNormResidual{F, LN2}}
const PostNormTransformerBlock{A, LN1, F, LN2} = TransformerBlock{PostNormResidual{A, LN1}, PostNormResidual{F, LN2}}
const PreNormTransformerDecoderBlock{A, LN1, C, LN2, F, LN3} =
TransformerDecoderBlock{ PreNormResidual{A, LN1}, PreNormResidual{C, LN2}, PreNormResidual{F, LN3}}
const PostNormTransformerDecoderBlock{A, LN1, C, LN2, F, LN3} =
TransformerDecoderBlock{PostNormResidual{A, LN1}, PostNormResidual{C, LN2}, PostNormResidual{F, LN3}}
const PreNormTransformerDecoderBlock{A, LN1, #=C, LN2,=# F, LN3} =
TransformerDecoderBlock{ PreNormResidual{A, LN1}, #=PreNormResidual{C, LN2},=# PreNormResidual{F, LN3}}
const PostNormTransformerDecoderBlock{A, LN1, #=C, LN2,=# F, LN3} =
TransformerDecoderBlock{PostNormResidual{A, LN1}, #=PostNormResidual{C, LN2},=# PostNormResidual{F, LN3}}

function Base.show(io::IO, t::PreNormTransformerBlock)
print(io, "PreNormTransformerBlock(");
@@ -115,13 +119,13 @@ end
function Base.show(io::IO, t::PreNormTransformerDecoderBlock)
print(io, "PreNormTransformerDecoderBlock(")
show(io, t.attention.layer); print(io, ", "); show(io, t.attention.norm); print(io, ", ");
show(io, t.crossattention.layer); print(io, ", "); show(io, t.crossattention.norm); print(io, ", ");
# show(io, t.crossattention.layer); print(io, ", "); show(io, t.crossattention.norm); print(io, ", ");
show(io, t.feedforward.layer); print(io, ", "); show(io, t.feedforward.norm); print(io, ')')
end
function Base.show(io::IO, t::PostNormTransformerDecoderBlock)
print(io, "PostNormTransformerDecoderBlock(")
show(io, t.attention.layer); print(io, ", "); show(io, t.attention.norm); print(io, ", ");
show(io, t.crossattention.layer); print(io, ", "); show(io, t.crossattention.norm); print(io, ", ");
#show(io, t.crossattention.layer); print(io, ", "); show(io, t.crossattention.norm); print(io, ", ");
show(io, t.feedforward.layer); print(io, ", "); show(io, t.feedforward.norm); print(io, ')')
end
_show_name(t::PreNormTransformerBlock) = "PreNormTransformerBlock"
@@ -131,8 +135,8 @@ _show_name(t::PostNormTransformerDecoderBlock) = "PostNormTransformerDecoderBloc

Flux._show_children(t::PreNormTransformerBlock) = (t.attention.layer, t.attention.norm, t.feedforward.layer, t.feedforward.norm)
Flux._show_children(t::PostNormTransformerBlock) = (t.attention.layer, t.attention.norm, t.feedforward.layer, t.feedforward.norm)
Flux._show_children(t::PreNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, t.crossattention.layer, t.crossattention.norm, t.feedforward.layer, t.feedforward.norm)
Flux._show_children(t::PostNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, t.crossattention.layer, t.crossattention.norm, t.feedforward.layer, t.feedforward.norm)
Flux._show_children(t::PreNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, #=t.crossattention.layer, t.crossattention.norm,=# t.feedforward.layer, t.feedforward.norm)
Flux._show_children(t::PostNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, #= t.crossattention.layer, t.crossattention.norm, =# t.feedforward.layer, t.feedforward.norm)

#############################################

@@ -147,7 +151,13 @@ function (sa::SelfAttention)(nt::NamedTuple)
qkv = apply_on_namedtuple(sa.qkv_proj, nt)
a = apply_on_namedtuple(sa.attention_op, qkv)
y = apply_on_namedtuple(sa.o_proj, a)
return y
return y
# NOTE: the commented-out lines below return a copy of y instead of y itself, as a
# workaround for an apparent memory leak seen when running Jevo under Distributed
# (suspected to be gradient-related). Copying cuts off gradient flow, so it is disabled here.
#hidden_state = zeros(Float32, size(y.hidden_state)) |> Flux.gpu
#hidden_state .= y.hidden_state
#return (hidden_state = hidden_state, attention_mask = y.attention_mask)
end

struct CrossAttention{A, Q, KV, O} <: LayerStruct
Expand Down Expand Up @@ -515,18 +525,14 @@ function PostNormTransformerDecoderBlock(
)
sa = SelfAttention(head, hidden_size, head_hidden_size;
dropout = attention_dropout, causal = true, return_score = return_self_attention_score)
ca = CrossAttention(head, hidden_size, head_hidden_size; dropout = cross_attention_dropout, return_score)
ff1 = Dense(act, hidden_size, intermediate_size)
ff2 = Dense(intermediate_size, hidden_size)
return TransformerDecoderBlock(
PostNormResidual(
DropoutLayer(sa, dropout),
sa,
LayerNorm(hidden_size)),
PostNormResidual(
DropoutLayer(ca, dropout),
LayerNorm(hidden_size)),
PostNormResidual(
DropoutLayer(Chain(ff1, ff2), dropout),
Chain(ff1, ff2),
LayerNorm(hidden_size)))
end

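Putting the layer.jl changes together: TransformerDecoderBlock now carries only a self-attention and a feedforward sub-layer, and the post-norm constructor no longer builds cross-attention or dropout wrappers. Below is a hand-assembled sketch of the resulting structure; the Layers-qualified helper names, the keyword defaults, and the gelu import are assumptions read off the hunks above, not verified against this branch:

using Flux                     # for Chain and the gelu activation
using Transformers.Layers

hidden, heads, head_hidden, inter = 512, 8, 64, 2048

blk = Layers.TransformerDecoderBlock(
    Layers.PostNormResidual(
        Layers.SelfAttention(heads, hidden, head_hidden; causal = true),
        Layers.LayerNorm(hidden)),
    Layers.PostNormResidual(
        Chain(Layers.Dense(gelu, hidden, inter), Layers.Dense(inter, hidden)),
        Layers.LayerNorm(hidden)))

# applying the block to a NamedTuple runs self-attention, then the feedforward residual:
nt = (hidden_state = randn(Float32, hidden, 128, 2),)
out = blk(nt)                  # out.hidden_state has size (hidden, 128, 2)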