Skip to content

Commit

Permalink
Merge branch 'antipattern'
Browse files Browse the repository at this point in the history
  • Loading branch information
singularitti committed Apr 21, 2020
2 parents fcee48c + 57a566a commit ca24c34
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/Collections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ end

# Miscellaneous
if VERSION >= v"1.3"
(eos::EquationOfState)(prop::PhysicalProperty) = v -> _evaluate(eos, prop, v)
(eos::EquationOfState)(property::PhysicalProperty) = v -> _evaluate(eos, property, v)
else
for T in (
:Murnaghan,
Expand All @@ -688,14 +688,15 @@ else
:Shanker,
)
eval(quote
(eos::$T)(prop::PhysicalProperty) = v -> _evaluate(eos, prop, v)
(eos::$T)(property::PhysicalProperty) = v -> _evaluate(eos, property, v)
end)
end # Julia 1.0-1.2 does not support adding methods to abstract types.
end

Base.:(==)(x::T, y::T) where {T<:EquationOfState} = all(fieldvalues(x) .== fieldvalues(y))

Base.eltype(::FieldValues{<:EquationOfState{T}}) where {T} = T
Base.eltype(::Type{<:EquationOfState{T}}) where {T} = T

function Base.getproperty(eos::EquationOfState, name::Symbol)
if name (:bp0, :bd0)
Expand Down
2 changes: 1 addition & 1 deletion src/Find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Find a volume which leads to the given pressure, energy, or bulk modulus based o
root-finding process.
"""
findvolume(f, y, x0, method) = find_zero(v -> f(v) - y, x0, method)
function findvolume(f, y, x0::Union{AbstractVector,Tuple})
function findvolume(f, y, x0)
for T in [subtypes(AbstractBisection); subtypes(AbstractAlefeldPotraShi)]
@info("Using method \"$T\"...")
try
Expand Down
3 changes: 2 additions & 1 deletion src/LinearFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ This module provides some linear fitting methods.
module LinearFitting

using LinearAlgebra: dot
using Polynomials: polyder, polyfit, degree, coeffs, Poly
using Polynomials: polyder, degree, coeffs, Poly
using Polynomials.PolyCompat: polyfit

using ..FiniteStrains

Expand Down
75 changes: 37 additions & 38 deletions src/NonlinearFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,48 @@ Fit an equation of state using least-squares fitting method (with the Levenberg-
# Arguments
- `eos::EquationOfState`: a trial equation of state. If it has units, `xdata` and `ydata` must also have.
- `prop::PhysicalProperty`: a `PhysicalProperty` instance. If `Energy`, fit ``E(V)``; if `Pressure`, fit ``P(V)``; if `BulkModulus`, fit ``B(V)``.
- `xdata::AbstractVector`: a vector of volumes (``V``), with(out) units.
- `ydata::AbstractVector`: a vector of energies (``E``), pressures (``P``), or bulk moduli (``B``), with(out) units. It must be consistent with `prop`.
- `xdata::AbstractArray`: an array of volumes (``V``), with(out) units.
- `ydata::AbstractArray`: an array of energies (``E``), pressures (``P``), or bulk moduli (``B``), with(out) units. It must be consistent with `prop`.
- `debug::Bool=false`: if `true`, then an `LsqFit.LsqFitResult` is returned, containing estimated Jacobian, residuals, etc.; if `false`, a fitted `EquationOfState` is returned. The default value is `false`.
- `kwargs`: the rest keyword arguments are the same as that of `LsqFit.curve_fit`. See its [documentation](https://github.com/JuliaNLSolvers/LsqFit.jl/blob/master/README.md)
and [tutorial](https://julianlsolvers.github.io/LsqFit.jl/latest/tutorial/).
"""
function lsqfit(
f::Function,
xdata::AbstractVector{<:Real},
ydata::AbstractVector{<:Real};
debug = false,
kwargs...,
)
T = constructorof(typeof(f.eos)) # Get the `UnionAll` type
model = (x, p) -> map(T(p...)(f.prop), x)
fitted = curve_fit(
model,
float(xdata), # Convert `xdata` elements to floats
float(ydata), # Convert `ydata` elements to floats
float.(fieldvalues(f.eos)); # TODO: What if these floats are different types?
kwargs...,
)
return debug ? fitted : T(fitted.param...)
end # function lsqfit
function lsqfit(
f::Function,
xdata::AbstractVector{<:AbstractQuantity},
ydata::AbstractVector{<:AbstractQuantity};
kwargs...,
function lsqfit(f, xdata, ydata; kwargs...)
eos, property = fieldvalues(f)
T = constructorof(typeof(eos)) # Get the `UnionAll` type
params, xdata, ydata = _preprocess(eos, xdata, ydata)
model = (x, p) -> map(T(p...)(property), x)
fit = curve_fit(model, xdata, ydata, params; kwargs...)
return _postprocess(T(fit.param...), eos)
end # function lsqfit

struct _Data{S,T}
data::T
end
_Data(data::T) where {T} = _Data{eltype(data),T}(data)

_preprocess(eos, xdata, ydata) = _preprocess(_Data(eos), _Data(xdata), _Data(ydata)) # Holy trait
_preprocess(eos::_Data{<:Real}, xdata::_Data{<:Real}, ydata::_Data{<:Real}) =
float.(fieldvalues(eos.data)), float(xdata.data), float(ydata.data)
function _preprocess(
eos::_Data{<:AbstractQuantity},
xdata::_Data{<:AbstractQuantity},
ydata::_Data{<:AbstractQuantity},
)
T = constructorof(typeof(f.eos)) # Get the `UnionAll` type
values = fieldvalues(f.eos)
values = fieldvalues(eos.data)
original_units = unit.(values) # Keep a record of `eos`'s units
g = x -> map(ustrip upreferred, x) # Convert to preferred units and strip the unit
trial_params = g.(values)
result = lsqfit(T(trial_params...)(f.prop), g.(xdata), g.(ydata); kwargs...)
if result isa EquationOfState # i.e., if `debug = false` and no error is thrown
return T((
x * upreferred(u) |> u for (x, u) in zip(fieldvalues(result), original_units)
)...) # Convert back to original `eos`'s units
else
return result
end
end # function lsqfit
f = x -> map(float ustrip upreferred, x) # Convert to preferred units and strip the unit
return map(f, (values, xdata.data, ydata.data))
end # function _preprocess

_postprocess(eos, trial_eos) = _postprocess(eos, _Data(trial_eos)) # Holy trait
_postprocess(eos, trial_eos::_Data{<:Real}) = eos
function _postprocess(eos, trial_eos::_Data{<:AbstractQuantity})
T = constructorof(typeof(trial_eos.data)) # Get the `UnionAll` type
original_units = unit.(fieldvalues(trial_eos.data)) # Keep a record of `eos`'s units
return T((
x * upreferred(u) |> u for (x, u) in zip(fieldvalues(eos), original_units)
)...) # Convert back to original `eos`'s units
end # function _postprocess

end
2 changes: 2 additions & 0 deletions test/Collections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ using EquationsOfState.Collections
BreenanStacey{Quantity{Float64}}
@test typeof(BreenanStacey(1 * u"nm^3", 2 * u"GPa", 3 // 1, 0 * u"eV")) ===
BreenanStacey{Quantity{Rational{Int}}}
@test BirchMurnaghan3rd(1 * u"angstrom^3", 2 * u"GPa", 4 // 1, 3 * u"eV").b′0 isa
DimensionlessQuantity
end

@testset "Test default EOS parameter `e0` and promotion" begin
Expand Down

0 comments on commit ca24c34

Please sign in to comment.