Merge remote-tracking branch 'pevnak/phi' into 0.3
chengchingwen committed May 15, 2024
2 parents 80be79f + 3bb4a58 commit 5d765ab
Showing 12 changed files with 424 additions and 29 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -33,6 +33,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrimitiveOneHot = "13d12f88-f12b-451e-9b9f-13b97e01cc85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"
@@ -76,8 +77,9 @@ LightXML = "0.9"
Metal = "1.1"
NNlib = "0.9"
NeuralAttentionlib = "0.2.12"
Pickle = "0.3.3"
Pickle = "0.3.5"
PrimitiveOneHot = "0.1"
SafeTensors = "1.1.1"
Static = "0.7, 0.8"
StringViews = "1"
StructWalk = "0.2"
4 changes: 4 additions & 0 deletions src/huggingface/download.jl
@@ -3,6 +3,8 @@ using HuggingFaceApi
using HuggingFaceApi: PYTORCH_WEIGHTS_NAME, CONFIG_NAME

const PYTORCH_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
const SAFETENSOR_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
const SAFETENSOR_WEIGHTS_NAME = "model.safetensors"

const VOCAB_FILE = "vocab.txt"
const VOCAB_JSON_FILE = "vocab.json"
@@ -45,6 +47,8 @@ hgf_file(model_name, file_name; revision = "main", kws...) = _hgf_download(hgf_f
hgf_model_config(model_name; kws...) = hgf_file(model_name, CONFIG_NAME; kws...)
hgf_model_weight(model_name; kws...) = hgf_file(model_name, PYTORCH_WEIGHTS_NAME; kws...)
hgf_model_weight_index(model_name; kws...) = hgf_file(model_name, PYTORCH_WEIGHTS_INDEX_NAME; kws...)
hgf_model_safetensor_weight(model_name; kws...) = hgf_file(model_name, SAFETENSOR_WEIGHTS_NAME; kws...)
hgf_model_safetensor_weight_index(model_name; kws...) = hgf_file(model_name, SAFETENSOR_WEIGHTS_INDEX_NAME; kws...)
hgf_vocab(model_name; kws...) = hgf_file(model_name, VOCAB_FILE; kws...)
hgf_vocab_json(model_name; kws...) = hgf_file(model_name, VOCAB_JSON_FILE; kws...)
hgf_tokenizer_special_tokens_map(model_name; kws...) = hgf_file(model_name, SPECIAL_TOKENS_MAP_FILE; kws...)
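For reference, a minimal sketch of how the new safetensors helpers can be called, mirroring the existing hgf_model_weight/hgf_model_weight_index pattern (the repo names below are only illustrative, and the helpers are assumed to be reachable under Transformers.HuggingFace like the other hgf_* download functions):

using Transformers.HuggingFace: hgf_model_safetensor_weight, hgf_model_safetensor_weight_index

# Download (and cache) a single-file safetensors checkpoint.
weight_path = hgf_model_safetensor_weight("microsoft/phi-1_5")

# Sharded checkpoints ship an index JSON listing the shard files instead.
index_path = hgf_model_safetensor_weight_index("meta-llama/Llama-2-7b-hf"; revision = "main")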
30 changes: 17 additions & 13 deletions src/huggingface/implementation/llama/attention.jl
@@ -7,49 +7,53 @@ using NeuralAttentionlib: $, AbstractAttenOp,
using Static
using ChainRulesCore


llama_rope_attention_score(dim, mask, p) =
llama_rope_attention_score(base, dim, mask, p) =
dropout_score(p) $
normalized_score(softmax) $
masked_score(NeuralAttentionlib.GenericMaskOp(), mask) $
scaled_dot_product_score $
(with_rotary_position_embedding(dim) ∘ gptneox_reorder(dim))
(with_rotary_position_embedding(base_position_func(base, dim), dim) ∘ gptneox_reorder(dim))

ChainRulesCore.@non_differentiable llama_rope_attention_score(args...)

function llama_rope_grouped_query_attention(dim, head, group, q, k, v, mask = nothing, p = nothing)
function llama_rope_grouped_query_attention(base, dim, head, group, q, k, v, mask = nothing, p = nothing)
return generic_grouped_query_attention(
weighted_sum_mixing, llama_rope_attention_score(dim, mask, p),
weighted_sum_mixing, llama_rope_attention_score(base, dim, mask, p),
head, group, q, k, v)
end
function llama_rope_grouped_query_attention(
::typeof(NeuralAttentionlib.score_returning),
dim, head, group, q, k, v, mask = nothing, p = nothing
base, dim, head, group, q, k, v, mask = nothing, p = nothing
)
return generic_grouped_query_attention(
NeuralAttentionlib.score_returning(weighted_sum_mixing),
llama_rope_attention_score(dim, mask, p),
llama_rope_attention_score(base, dim, mask, p),
head, group, q, k, v)
end

struct CausalLlamaRoPEGroupedQueryAttenOp{D, P} <: AbstractAttenOp
struct CausalLlamaRoPEGroupedQueryAttenOp{F, D, P} <: AbstractAttenOp
base::F
dim::D
head::Int
group::Int
p::P
end
CausalLlamaRoPEGroupedQueryAttenOp(head::Int, group::Int) =
CausalLlamaRoPEGroupedQueryAttenOp(nothing, head, group, nothing)
CausalLlamaRoPEGroupedQueryAttenOp(1e4, head, group)
CausalLlamaRoPEGroupedQueryAttenOp(dim::Int, head::Int, group::Int) =
CausalLlamaRoPEGroupedQueryAttenOp(dim, head, group, nothing)
CausalLlamaRoPEGroupedQueryAttenOp(1e4, dim, head, group)
CausalLlamaRoPEGroupedQueryAttenOp(base, head::Int, group::Int) =
CausalLlamaRoPEGroupedQueryAttenOp(base, nothing, head, group, nothing)
CausalLlamaRoPEGroupedQueryAttenOp(base, dim::Int, head::Int, group::Int) =
CausalLlamaRoPEGroupedQueryAttenOp(base, dim, head, group, nothing)

NeuralAttentionlib.get_attention_func(::CausalLlamaRoPEGroupedQueryAttenOp) = llama_rope_grouped_query_attention
NeuralAttentionlib.get_attention_func_args(op::CausalLlamaRoPEGroupedQueryAttenOp, q, k, v, mask = nothing) =
(op.dim, op.head, op.group, q, k, v, BatchedMask(CausalMask() & mask), op.p)
(op.base, op.dim, op.head, op.group, q, k, v, BatchedMask(CausalMask() & mask), op.p)

Layers.set_dropout(op::CausalLlamaRoPEGroupedQueryAttenOp, p) = CausalLlamaRoPEGroupedQueryAttenOp(op.dim, op.head, op.group, p)
Layers.set_dropout(op::CausalLlamaRoPEGroupedQueryAttenOp, p) = CausalLlamaRoPEGroupedQueryAttenOp(op.base, op.dim, op.head, op.group, p)

const CausalLlamaRoPEGroupedQueryAttenOpWithScore{D, P} = NeuralAttentionlib.WithScore{CausalLlamaRoPEGroupedQueryAttenOp{D, P}}
const CausalLlamaRoPEGroupedQueryAttenOpWithScore{F, D, P} = NeuralAttentionlib.WithScore{CausalLlamaRoPEGroupedQueryAttenOp{F, D, P}}

Layers.argument_names(::CausalLlamaRoPEGroupedQueryAttenOp) = (:hidden_state, :attention_mask)
Layers.apply_on_namedtuple(op::CausalLlamaRoPEGroupedQueryAttenOp, nt::NamedTuple) = Layers.apply_attention_op(op, nt)
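A quick sketch of the updated constructors: the RoPE base now comes first and defaults to 1e4, matching the previously hard-coded value. The op type is internal to Transformers.HuggingFace, and the head/group counts below are illustrative:

using Transformers.HuggingFace: CausalLlamaRoPEGroupedQueryAttenOp  # internal, non-exported type

op_default = CausalLlamaRoPEGroupedQueryAttenOp(32, 8)           # base defaults to 1e4
op_base    = CausalLlamaRoPEGroupedQueryAttenOp(1e6, 32, 8)      # explicit rope_theta base
op_full    = CausalLlamaRoPEGroupedQueryAttenOp(1e6, 64, 32, 8)  # explicit rotary dim as well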
1 change: 1 addition & 0 deletions src/huggingface/implementation/llama/config.jl
@@ -15,6 +15,7 @@
eos_token_id::Int = 2
pretraining_tp::Int = 1
tie_word_embeddings::Bool = false
rope_theta::Float64 = 1e4
rope_scaling::Nothing = nothing
clean_up_tokenization_spaces::Bool = false
end
5 changes: 3 additions & 2 deletions src/huggingface/implementation/llama/load.jl
@@ -76,6 +76,7 @@ function load_model(_type::Type{<:HGFLlamaPreTrainedModel}, ::Type{<:SelfAttenti
@assert head % kv_head == 0 "The number of query heads is not divisible by the number of key/value heads"
return_score = cfg[:output_attentions]
factor = Float32(cfg[:initializer_range])
rotary_pe_base = Float64(cfg[:rope_theta])
@assert isnothing(cfg[:rope_scaling]) "Scaling Rotary Embedding is not supported yet"
q_weight = getweight(weight_init(dims, dims, factor), Array,
state_dict, joinname(prefix, "q_proj.weight"))
@@ -87,9 +88,9 @@ function load_model(_type::Type{<:HGFLlamaPreTrainedModel}, ::Type{<:SelfAttenti
qkv_proj = Layers.Fork(Layers.Dense(q_weight), Layers.Dense(k_weight), Layers.Dense(v_weight))
o_proj = Layers.Dense(o_weight)
if grouped_attn
op = CausalLlamaRoPEGroupedQueryAttenOp(head, kv_head)
op = CausalLlamaRoPEGroupedQueryAttenOp(rotary_pe_base, head, kv_head)
else
op = CausalGPTNeoXRoPEMultiheadQKVAttenOp(head_dims, head)
op = CausalGPTNeoXRoPEMultiheadQKVAttenOp(rotary_pe_base, head_dims, head)
end
return_score && (op = WithScore(op))
return SelfAttention(op, qkv_proj, o_proj)
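In short, the loader now threads cfg[:rope_theta] through to the attention op as the rotary base. A self-contained restatement of the branch above, with illustrative stand-ins for the values read from the config (the op types are internal to Transformers.HuggingFace):

rope_theta = 1e4                       # cfg[:rope_theta]; larger for long-context checkpoints
head, kv_head, head_dims = 32, 8, 128  # illustrative attention geometry
grouped_attn = kv_head != head

op = grouped_attn ?
    CausalLlamaRoPEGroupedQueryAttenOp(rope_theta, head, kv_head) :
    CausalGPTNeoXRoPEMultiheadQKVAttenOp(rope_theta, head_dims, head)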
32 changes: 32 additions & 0 deletions src/huggingface/implementation/phi/config.jl
@@ -0,0 +1,32 @@
@defaultdef :phi struct HGFPhiConfigDefault
vocab_size::Int = 51200
hidden_size::Int = 2048
intermediate_size::Int = 8192
num_hidden_layers::Int = 24
num_attention_heads::Int = 32
num_key_value_heads::Nothing = nothing
resid_pdrop::Float64 = 0.0
embd_pdrop::Float64 = 0.0
attention_dropout::Float64 = 0.0
hidden_act::String = "gelu_new"
max_position_embeddings::Int = 2048
initializer_range::Float64 = 0.02
layer_norm_eps::Float64 = 1e-5
use_cache::Bool = true
tie_word_embeddings::Bool = false
rope_theta::Int = 10000
rope_scaling::Nothing = nothing
partial_rotary_factor::Float64 = 0.5
qk_layernorm::Bool = false
bos_token_id::Int = 1
eos_token_id::Int = 2
end

const HGFPhiConfig = HGFConfig{:phi}

function HGFConfig{:phi}(cfg, overwrite)
if !haskey(cfg, :num_key_value_heads)
overwrite[:num_key_value_heads] = get(cfg, :num_attention_heads, 32)
end
return HGFConfig(:phi, cfg, overwrite)
end
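A minimal sketch of the backfill behaviour, assuming plain Dicts can stand in for the parsed config.json and the overwrite table that the loader normally supplies (names and values below are illustrative):

using Transformers.HuggingFace: HGFConfig

cfg = Dict{Symbol, Any}(:num_attention_heads => 32, :hidden_size => 2048)
overwrite = Dict{Symbol, Any}()
phi_cfg = HGFConfig{:phi}(cfg, overwrite)
phi_cfg[:num_key_value_heads]  # expected to be 32, backfilled from :num_attention_heads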
213 changes: 213 additions & 0 deletions src/huggingface/implementation/phi/load.jl
@@ -0,0 +1,213 @@
using ..Layers
using ..Layers: CompositeEmbedding, SelfAttention
using ChainRulesCore
using Functors
using Static

using NeuralAttentionlib
using NeuralAttentionlib: WithScore


abstract type HGFPhiPreTrainedModel <: HGFPreTrainedModel end

struct HGFPhiModel{E, DEC} <: HGFPhiPreTrainedModel
embed::E
decoder::DEC
end
@functor HGFPhiModel

(model::HGFPhiModel)(nt::NamedTuple) = model.decoder(model.embed(nt))

for T in :[
HGFPhiForCausalLM,
# HGFPhiForSequenceClassification,
# HGFPhiForTokenClassification,
].args
@eval begin
@hgfdefmodel $T HGFPhiPreTrainedModel
end
end

basemodelkey(::Type{<:HGFPhiPreTrainedModel}) = :model
isbasemodel(::Type{<:HGFPhiModel}) = true
isbasemodel(::Type{<:HGFPhiPreTrainedModel}) = false

get_model_type(::Val{:phi}) = (
model = HGFPhiModel,
forcausallm = HGFPhiForCausalLM,
)

function load_model(_type::Type{HGFPhiModel}, cfg, state_dict, prefix)
embed = load_model(_type, CompositeEmbedding, cfg, state_dict, prefix)
decoder = load_model(_type, TransformerBlock, cfg, state_dict, prefix)
return HGFPhiModel(embed, decoder)
end

function load_model(_type::Type{HGFPhiForCausalLM}, cfg, state_dict, prefix)
model = load_model(HGFPhiModel, cfg, state_dict, joinname(prefix, "model"))
vocab_size, dims, factor = cfg[:vocab_size], cfg[:hidden_size], Float32(cfg[:initializer_range])
if cfg[:tie_word_embeddings]
embedding = model.embed.token.embeddings
else
embedding = getweight(weight_init(vocab_size, dims, factor), Layers.Embed,
state_dict, joinname(prefix, "lm_head.weight"))
end
bias = getweight(zero_init(vocab_size), Array, state_dict, joinname(prefix, "lm_head.bias"))
lmhead = Layers.EmbedDecoder(Layers.Embed(embedding), bias)
return HGFPhiForCausalLM(model, Layers.Branch{(:logit,), (:hidden_state,)}(lmhead))
end

function load_model(_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:CompositeEmbedding}, cfg, state_dict, prefix)
vocab_size, dims = cfg[:vocab_size], cfg[:hidden_size]
factor = Float32(cfg[:initializer_range])
token_weight = getweight(weight_init(vocab_size, dims, factor), Layers.Embed,
state_dict, joinname(prefix, "embed_tokens.weight"))
p = cfg[:embd_pdrop]; p = iszero(p) ? nothing : p
embed = CompositeEmbedding(token = Layers.Embed(token_weight))
return Layers.DropoutLayer(embed, p)
end

function load_model(
_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:Layers.Chain{<:Tuple{Layers.Dense, Layers.Dense}}},
cfg, state_dict, prefix
)
dims = cfg[:hidden_size]
ff_dims = cfg[:intermediate_size]
factor = Float32(cfg[:initializer_range])
act = ACT2FN[Symbol(cfg[:hidden_act])]
wi_weight = getweight(weight_init(dims, ff_dims, factor), Array,
state_dict, joinname(prefix, "fc1.weight"))
wi_bias = getweight(zero_init(ff_dims), Array, state_dict, joinname(prefix, "fc1.bias"))
wo_weight = getweight(weight_init(ff_dims, dims, factor), Array,
state_dict, joinname(prefix, "fc2.weight"))
wo_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "fc2.bias"))
return Layers.Chain(Layers.Dense(act, wi_weight, wi_bias), Layers.Dense(wo_weight, wo_bias))
end

function load_model(_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:SelfAttention}, cfg, state_dict, prefix)
head, dims = cfg[:num_attention_heads], cfg[:hidden_size]
@assert dims % head == 0 "The hidden size is not a multiple of the number of attention heads."
p = cfg[:attention_dropout]; p = iszero(p) ? nothing : p
head_dims = div(dims, head)
kv_head = something(cfg[:num_key_value_heads], head)
grouped_attn = kv_head != head
@assert head % kv_head == 0 "The number of query heads is not divisible by the number of key/value heads"
return_score = cfg[:output_attentions]
factor = Float32(cfg[:initializer_range])
rotary_dim = floor(Int, cfg[:partial_rotary_factor] * head_dims)
rotary_pe_base = Float64(cfg[:rope_theta])
@assert isnothing(cfg[:rope_scaling]) "Scaling Rotary Embedding is not supported yet"
kv_dims = kv_head * head_dims
q_weight = getweight(weight_init(dims, dims, factor), Array, state_dict, joinname(prefix, "q_proj.weight"))
q_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "q_proj.bias"))
k_weight = getweight(weight_init(dims, kv_dims, factor), Array, state_dict, joinname(prefix, "k_proj.weight"))
k_bias = getweight(zero_init(kv_dims), Array, state_dict, joinname(prefix, "k_proj.bias"))
v_weight = getweight(weight_init(dims, kv_dims, factor), Array, state_dict, joinname(prefix, "v_proj.weight"))
v_bias = getweight(zero_init(kv_dims), Array, state_dict, joinname(prefix, "v_proj.bias"))
o_weight = getweight(weight_init(dims, dims, factor), Array, state_dict, joinname(prefix, "dense.weight"))
o_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "dense.bias"))
query = Layers.Dense(q_weight, q_bias)
key = Layers.Dense(k_weight, k_bias)
value = Layers.Dense(v_weight, v_bias)
if cfg[:qk_layernorm]
ln_ϵ = Float32(cfg[:layer_norm_eps])
q_layernorm = _load_layernorm(state_dict, joinname(prefix, "q_layernorm"), head_dims, ln_ϵ)
k_layernorm = _load_layernorm(state_dict, joinname(prefix, "k_layernorm"), head_dims, ln_ϵ)
query = Layers.Chain(query, q_layernorm)
key = Layers.Chain(key, k_layernorm)
end
qkv_proj = Layers.Fork(query, key, value)
o_proj = Layers.Dense(o_weight, o_bias)
if grouped_attn
op = CausalLlamaRoPEGroupedQueryAttenOp(rotary_pe_base, rotary_dim, head, kv_head, p)
else
op = CausalGPTNeoXRoPEMultiheadQKVAttenOp(rotary_pe_base, rotary_dim, head, p)
end
return_score && (op = WithScore(op))
return SelfAttention(op, qkv_proj, o_proj)
end

function load_model(::Type{<:HGFPhiPreTrainedModel}, ::Type{<:Layers.LayerNorm}, cfg, state_dict, prefix)
dims = cfg[:hidden_size]
ln_ϵ = Float32(cfg[:layer_norm_eps])
ln_weight = getweight(one_init(dims), Array, state_dict, joinname(prefix, "weight"))
ln_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "bias"))
return Layers.LayerNorm(ln_weight, ln_bias, ln_ϵ)
end

function load_model(_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:TransformerBlock}, cfg, state_dict, prefix)
n = cfg[:num_hidden_layers]
p = cfg[:resid_pdrop]; p = iszero(p) ? nothing : p
collect_output = cfg[:output_attentions] || cfg[:output_hidden_states]
blocks = []
for i = 1:n
lprefix = joinname(prefix, :layers, i-1)
ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(lprefix, "input_layernorm"))
sa = load_model(_type, SelfAttention, cfg, state_dict, joinname(lprefix, "self_attn"))
sa = Layers.DropoutLayer(sa, p)
ff = load_model(_type, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}}, cfg, state_dict, joinname(lprefix, "mlp"))
ff = Layers.DropoutLayer(ff, p)
block = ParallelPreNormTransformerBlock(sa, ff, ln)
push!(blocks, block)
end
collect_f = collect_output ? Layers.collect_outputs : nothing
trf = Transformer(Tuple(blocks), collect_f)
final_ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(prefix, "final_layernorm"))
return Layers.Chain(trf, final_ln)
end


function get_state_dict(m::HGFPhiModel, state_dict, prefix)
get_state_dict(HGFPhiModel, m.embed, state_dict, prefix)
get_state_dict(HGFPhiModel, m.decoder[1], state_dict, prefix)
get_state_dict(HGFPhiModel, m.decoder[2], state_dict, joinname(prefix, "final_layernorm"))
return state_dict
end

function get_state_dict(m::HGFPhiForCausalLM, state_dict, prefix)
get_state_dict(m.model, state_dict, joinname(prefix, "model"))
get_state_dict(HGFPhiModel, m.cls.layer, state_dict, joinname(prefix, "lm_head"))
return state_dict
end

function get_state_dict(p::Type{<:HGFPhiPreTrainedModel}, m::CompositeEmbedding, state_dict, prefix)
get_state_dict(p, m.token, state_dict, joinname(prefix, "embed_tokens"))
return state_dict
end

function get_state_dict(p::Type{<:HGFPhiPreTrainedModel}, m::SelfAttention, state_dict, prefix)
q, k, v = m.qkv_proj.layers
if q isa Layers.Chain
get_state_dict(p, q[1], state_dict, joinname(prefix, "q_proj"))
get_state_dict(p, k[1], state_dict, joinname(prefix, "k_proj"))
get_state_dict(p, q[2], state_dict, joinname(prefix, "q_layernorm"))
get_state_dict(p, k[2], state_dict, joinname(prefix, "k_layernorm"))
else
get_state_dict(p, q, state_dict, joinname(prefix, "q_proj"))
get_state_dict(p, k, state_dict, joinname(prefix, "k_proj"))
end
get_state_dict(p, v, state_dict, joinname(prefix, "v_proj"))
get_state_dict(p, m.o_proj, state_dict, joinname(prefix, "dense"))
return state_dict
end

function get_state_dict(p::Type{<:HGFPhiPreTrainedModel}, m::Layers.Chain{<:Tuple{Layers.Dense, Layers.Dense}},
state_dict, prefix)
get_state_dict(p, m[1], state_dict, joinname(prefix, "fc1"))
get_state_dict(p, m[2], state_dict, joinname(prefix, "fc2"))
return state_dict
end

function get_state_dict(p::Type{<:HGFPhiPreTrainedModel}, m::ParallelPreNormTransformerBlock, state_dict, prefix)
get_state_dict(p, m.norm, state_dict, joinname(prefix, "input_layernorm"))
get_state_dict(p, m.attention, state_dict, joinname(prefix, "self_attn"))
get_state_dict(p, m.feedforward, state_dict, joinname(prefix, "mlp"))
return state_dict
end

function get_state_dict(p::Type{<:HGFPhiPreTrainedModel}, m::Transformer, state_dict, prefix)
for (i, t) in enumerate(m.blocks)
get_state_dict(p, t, state_dict, joinname(prefix, :layers, i-1))
end
return state_dict
end
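With the :phi task mapping registered above, loading should follow the same high-level path as the other architectures. A hedged sketch; the repo name is illustrative, and the hgf"" usage is assumed to mirror the existing "<repo>:<task>" pattern rather than confirmed by this diff:

using Transformers
using Transformers.HuggingFace

# Task mapping registered in this file:
HuggingFace.get_model_type(Val(:phi))
# -> (model = HGFPhiModel, forcausallm = HGFPhiForCausalLM)

# Assumed end-to-end usage ("microsoft/phi-1_5" is an illustrative repo name):
model = hgf"microsoft/phi-1_5:ForCausalLM"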
2 changes: 2 additions & 0 deletions src/huggingface/implementation/phi/phi.jl
@@ -0,0 +1,2 @@
include("config.jl")
include("load.jl")
