Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TOML: Improve type-stability #55016

Merged
merged 1 commit into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ struct TOMLCache{Dates}
d::Dict{String, CachedTOMLDict}
end
TOMLCache(p::TOML.Parser) = TOMLCache(p, Dict{String, CachedTOMLDict}())
# TODO: Delete this converting constructor once Pkg stops using it
TOMLCache(p::TOML.Parser, d::Dict{String, Dict{String, Any}}) = TOMLCache(p, convert(Dict{String, CachedTOMLDict}, d))

const TOML_CACHE = TOMLCache(TOML.Parser{nothing}())
Expand Down
100 changes: 53 additions & 47 deletions base/toml_parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ mutable struct Parser{Dates}

# Filled in in case we are parsing a file to improve error messages
filepath::Union{String, Nothing}

# Optionally populate with the Dates stdlib to change the type of Date types returned
Dates::Union{Module, Nothing} # TODO: remove once Pkg is updated
end

function Parser{Dates}(str::String; filepath=nothing) where {Dates}
Expand All @@ -106,8 +103,7 @@ function Parser{Dates}(str::String; filepath=nothing) where {Dates}
IdSet{Any}(), # static_arrays
IdSet{TOMLDict}(), # defined_tables
root,
filepath,
nothing
filepath
)
startup(l)
return l
Expand Down Expand Up @@ -495,8 +491,10 @@ function recurse_dict!(l::Parser, d::Dict, dotted_keys::AbstractVector{String},
d = d::TOMLDict
key = dotted_keys[i]
d = get!(TOMLDict, d, key)
if d isa Vector
if d isa Vector{Any}
d = d[end]
elseif d isa Vector
return ParserError(ErrKeyAlreadyHasValue)
end
check && @try check_allowed_add_key(l, d, i == length(dotted_keys))
end
Expand Down Expand Up @@ -537,7 +535,7 @@ function parse_array_table(l)::Union{Nothing, ParserError}
end
d = @try recurse_dict!(l, l.root, @view(table_key[1:end-1]), false)
k = table_key[end]
old = get!(() -> [], d, k)
old = get!(() -> Any[], d, k)
if old isa Vector
if old in l.static_arrays
return ParserError(ErrAddArrayToStaticArray)
Expand All @@ -546,7 +544,7 @@ function parse_array_table(l)::Union{Nothing, ParserError}
return ParserError(ErrArrayTreatedAsDictionary)
end
d_new = TOMLDict()
push!(old, d_new)
push!(old::Vector{Any}, d_new)
push!(l.defined_tables, d_new)
l.active_table = d_new

Expand Down Expand Up @@ -668,41 +666,20 @@ end
# Array #
#########

function push!!(v::Vector, el)
# Since these types are typically non-inferable, they are a big invalidation risk,
# and since it's used by the package-loading infrastructure the cost of invalidation
# is high. Therefore, this is written to reduce the "exposed surface area": e.g., rather
# than writing `T[el]` we write it as `push!(Vector{T}(undef, 1), el)` so that there
# is no ambiguity about what types of objects will be created.
T = eltype(v)
t = typeof(el)
if el isa T || t === T
push!(v, el::T)
return v
elseif T === Union{}
out = Vector{t}(undef, 1)
out[1] = el
return out
else
if T isa Union
newT = Any
else
newT = Union{T, typeof(el)}
end
new = Array{newT}(undef, length(v))
copy!(new, v)
return push!(new, el)
function copyto_typed!(a::Vector{T}, b::Vector) where T
for i in 1:length(b)
a[i] = b[i]::T
end
return nothing
end

function parse_array(l::Parser)::Err{Vector}
function parse_array(l::Parser{Dates})::Err{Vector} where Dates
skip_ws_nl(l)
array = Vector{Union{}}()
array = Vector{Any}()
empty_array = accept(l, ']')
while !empty_array
v = @try parse_value(l)
# TODO: Worth to function barrier this?
array = push!!(array, v)
array = push!(array, v)
# There can be an arbitrary number of newlines and comments before a value and before the closing bracket.
skip_ws_nl(l)
comma = accept(l, ',')
Expand All @@ -712,8 +689,40 @@ function parse_array(l::Parser)::Err{Vector}
return ParserError(ErrExpectedCommaBetweenItemsArray)
end
end
push!(l.static_arrays, array)
return array
# check for static type throughout array
T = !isempty(array) ? typeof(array[1]) : Union{}
for el in array
if typeof(el) != T
T = Any
break
end
end
if T === Any
new = array
elseif T === String
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Bool
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Int64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === UInt64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Float64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Union{}
new = Any[]
elseif (T === TOMLDict) || (T == BigInt) || (T === UInt128) || (T === Int128) || (T <: Vector) ||
(T === Dates.Date) || (T === Dates.Time) || (T === Dates.DateTime)
# do nothing, leave as Vector{Any}
new = array
else @assert false end
push!(l.static_arrays, new)
return new
end


Expand Down Expand Up @@ -1025,10 +1034,9 @@ function parse_datetime(l)
end

function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.DateTime(year, month, day, h, m, s, ms)
return Dates.DateTime(year, month, day, h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1039,10 +1047,9 @@ function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) wh
end

function try_return_date(p::Parser{Dates}, year, month, day) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Date(year, month, day)
return Dates.Date(year, month, day)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1062,10 +1069,9 @@ function parse_local_time(l::Parser)
end

function try_return_time(p::Parser{Dates}, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Time(h, m, s, ms)
return Dates.Time(h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand Down
2 changes: 1 addition & 1 deletion stdlib/TOML/test/values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ end
@testset "Array" begin
@test testval("[1,2,3]", Int64[1,2,3])
@test testval("[1.0, 2.0, 3.0]", Float64[1.0, 2.0, 3.0])
@test testval("[1.0, 2.0, 3]", Union{Int64, Float64}[1.0, 2.0, Int64(3)])
Copy link
Member

@KristofferC KristofferC Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just so I know, what are the rules now for when it will promote to Any? IIRC, it used to keep it as a union up to two different eltypes, but that has changed now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now specializes for Vector{String}, Vector{Bool}, Vector{Int64}, Vector{UInt64}, and Vector{Float64}.

Anything else (including empty arrays and arrays with mixed element types) is returned as Vector{Any}.

@test testval("[1.0, 2.0, 3]", Any[1.0, 2.0, Int64(3)])
@test testval("[1.0, 2, \"foo\"]", Any[1.0, Int64(2), "foo"])
end