Skip to content

Commit

Permalink
update phi & test
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed May 15, 2024
1 parent 5d765ab commit f475ffe
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 39 deletions.
4 changes: 1 addition & 3 deletions src/huggingface/implementation/phi/config.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@defaultdef :phi struct HGFPhiConfigDefault
@hgfcfg :phi struct HGFPhiConfig
vocab_size::Int = 51200
hidden_size::Int = 2048
intermediate_size::Int = 8192
Expand All @@ -22,8 +22,6 @@
eos_token_id::Int = 2
end

const HGFPhiConfig = HGFConfig{:phi}

function HGFConfig{:phi}(cfg, overwrite)
if !haskey(cfg, :num_key_value_heads)
overwrite[:num_key_value_heads] = get(cfg, :num_attention_heads, 32)
Expand Down
31 changes: 4 additions & 27 deletions src/huggingface/implementation/phi/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,12 @@ using Static
using NeuralAttentionlib
using NeuralAttentionlib: WithScore


abstract type HGFPhiPreTrainedModel <: HGFPreTrainedModel end

struct HGFPhiModel{E, DEC} <: HGFPhiPreTrainedModel
embed::E
decoder::DEC
end
@functor HGFPhiModel

(model::HGFPhiModel)(nt::NamedTuple) = model.decoder(model.embed(nt))

for T in :[
HGFPhiForCausalLM,
# HGFPhiForSequenceClassification,
# HGFPhiForTokenClassification,
].args
@eval begin
@hgfdefmodel $T HGFPhiPreTrainedModel
end
end
@hgfdef Phi (
Model => (embed, decoder),
ForCausalLM,
)

basemodelkey(::Type{<:HGFPhiPreTrainedModel}) = :model
isbasemodel(::Type{<:HGFPhiModel}) = true
isbasemodel(::Type{<:HGFPhiPreTrainedModel}) = false

get_model_type(::Val{:phi}) = (
model = HGFPhiModel,
forcausallm = HGFPhiForCausalLM,
)

function load_model(_type::Type{HGFPhiModel}, cfg, state_dict, prefix)
embed = load_model(_type, CompositeEmbedding, cfg, state_dict, prefix)
Expand Down
25 changes: 16 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,29 @@ dzeros(arg...) = zeros(arg...) |> device
const tests = [
"tokenizer",
"huggingface",
"loss.jl",
"grad.jl",
]

Random.seed!(0)

@testset "Transformers" begin
for t in tests
name = titlecase(t)
@testset "$name" begin
@info "Test $name"
for f readdir(joinpath(@__DIR__, t))
endswith(f, ".jl") || continue
t == "huggingface" && f == "tokenizer.jl" && !envget("JL_TRF_TEST_TKR") && continue
include(joinpath(@__DIR__, t, f))
path = joinpath(@__DIR__, t)
if isdir(path)
name = titlecase(t)
@testset "$name" begin
@info "Test $name"
for f readdir(path)
endswith(f, ".jl") || continue
t == "huggingface" && f == "tokenizer.jl" && !envget("JL_TRF_TEST_TKR") && continue
include(joinpath(path, f))
end
end
else
name = titlecase(first(splitext(t)))
@info "Test $name"
include(path)
end
end
include("loss.jl")
include("grad.jl")
end

0 comments on commit f475ffe

Please sign in to comment.