Skip to content

Commit

Permalink
TOML: Improve type-stability
Browse files Browse the repository at this point in the history
This changes the output of the TOML parser to specialize
`Vector{T}` less aggressively, so that combinatorially expensive types
like `Vector{Vector{Float64}}` or `Vector{Union{Float64,Int64}}` are
instead returned as `Vector{Any}`.

Vectors of homogeneous leaf types, like `Vector{Float64}`, are still
supported as before.

This change makes the TOML parser fully type-stable, except for its
dynamic usage of Dates.

Co-authored-by: Gabriel Baraldi <[email protected]>
  • Loading branch information
topolarity and gbaraldi committed Jul 29, 2024
1 parent 4dfce5d commit c6c4979
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 47 deletions.
1 change: 0 additions & 1 deletion base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ struct TOMLCache{Dates}
d::Dict{String, CachedTOMLDict}
end
TOMLCache(p::TOML.Parser) = TOMLCache(p, Dict{String, CachedTOMLDict}())
# TODO: Delete this converting constructor once Pkg stops using it
TOMLCache(p::TOML.Parser, d::Dict{String, Dict{String, Any}}) = TOMLCache(p, convert(Dict{String, CachedTOMLDict}, d))

const TOML_CACHE = TOMLCache(TOML.Parser{nothing}())
Expand Down
98 changes: 52 additions & 46 deletions base/toml_parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ mutable struct Parser{Dates}

# Filled in in case we are parsing a file to improve error messages
filepath::Union{String, Nothing}

# Optionally populate with the Dates stdlib to change the type of Date types returned
Dates::Union{Module, Nothing} # TODO: remove once Pkg is updated
end

function Parser{Dates}(str::String; filepath=nothing) where {Dates}
Expand All @@ -106,8 +103,7 @@ function Parser{Dates}(str::String; filepath=nothing) where {Dates}
IdSet{Any}(), # static_arrays
IdSet{TOMLDict}(), # defined_tables
root,
filepath,
nothing
filepath
)
startup(l)
return l
Expand Down Expand Up @@ -495,8 +491,10 @@ function recurse_dict!(l::Parser, d::Dict, dotted_keys::AbstractVector{String},
d = d::TOMLDict
key = dotted_keys[i]
d = get!(TOMLDict, d, key)
if d isa Vector
if d isa Vector{Any}
d = d[end]
elseif d isa Vector
return ParserError(ErrKeyAlreadyHasValue)
end
check && @try check_allowed_add_key(l, d, i == length(dotted_keys))
end
Expand Down Expand Up @@ -538,7 +536,7 @@ function parse_array_table(l)::Union{Nothing, ParserError}
d = @try recurse_dict!(l, l.root, @view(table_key[1:end-1]), false)
k = table_key[end]
old = get!(() -> [], d, k)
if old isa Vector
if old isa Vector{Any}
if old in l.static_arrays
return ParserError(ErrAddArrayToStaticArray)
end
Expand Down Expand Up @@ -668,41 +666,20 @@ end
# Array #
#########

function push!!(v::Vector, el)
# Since these types are typically non-inferable, they are a big invalidation risk,
# and since it's used by the package-loading infrastructure the cost of invalidation
# is high. Therefore, this is written to reduce the "exposed surface area": e.g., rather
# than writing `T[el]` we write it as `push!(Vector{T}(undef, 1), el)` so that there
# is no ambiguity about what types of objects will be created.
T = eltype(v)
t = typeof(el)
if el isa T || t === T
push!(v, el::T)
return v
elseif T === Union{}
out = Vector{t}(undef, 1)
out[1] = el
return out
else
if T isa Union
newT = Any
else
newT = Union{T, typeof(el)}
end
new = Array{newT}(undef, length(v))
copy!(new, v)
return push!(new, el)
# Copy the elements of `b` into the same positions of `a`, type-asserting each
# element to `T` (the element type of `a`) before it is stored. An element that
# is not a `T` raises a `TypeError` from the assertion. Always returns `nothing`;
# the result is observed through the mutation of `a`.
function copyto_typed!(a::Vector{T}, b::Vector) where T
    for (pos, item) in enumerate(b)
        # The assertion narrows the element type before the store.
        a[pos] = item::T
    end
    return nothing
end

function parse_array(l::Parser)::Err{Vector}
function parse_array(l::Parser{Dates})::Err{Vector} where Dates
skip_ws_nl(l)
array = Vector{Union{}}()
array = Vector{Any}()
empty_array = accept(l, ']')
while !empty_array
v = @try parse_value(l)
# TODO: Worth to function barrier this?
array = push!!(array, v)
array = push!(array, v)
# There can be an arbitrary number of newlines and comments before a value and before the closing bracket.
skip_ws_nl(l)
comma = accept(l, ',')
Expand All @@ -712,8 +689,40 @@ function parse_array(l::Parser)::Err{Vector}
return ParserError(ErrExpectedCommaBetweenItemsArray)
end
end
push!(l.static_arrays, array)
return array
# check for static type throughout array
T = !isempty(array) ? typeof(array[1]) : Union{}
for el in array
if typeof(el) != T
T = Any
break
end
end
if T === Any
new = array
elseif T === String
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Bool
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Int64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === UInt64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Float64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Union{}
new = Union{}[]
elseif (T === TOMLDict) || (T == BigInt) || (T === UInt128) || (T === Int128) || (T <: Vector) ||
(T === Dates.Date) || (T === Dates.Time) || (T === Dates.DateTime)
# do nothing, leave as Vector{Any}
new = array
else @assert false end
push!(l.static_arrays, new)
return new
end


Expand Down Expand Up @@ -1025,10 +1034,9 @@ function parse_datetime(l)
end

function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.DateTime(year, month, day, h, m, s, ms)
return Dates.DateTime(year, month, day, h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1039,10 +1047,9 @@ function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) wh
end

function try_return_date(p::Parser{Dates}, year, month, day) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Date(year, month, day)
return Dates.Date(year, month, day)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1062,10 +1069,9 @@ function parse_local_time(l::Parser)
end

function try_return_time(p::Parser{Dates}, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Time(h, m, s, ms)
return Dates.Time(h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand Down

0 comments on commit c6c4979

Please sign in to comment.