Skip to content

Commit

Permalink
Merge pull request #34 from MineralsCloud/feature/trait
Browse files Browse the repository at this point in the history
Refactor module `NonlinearFitting` using traits
  • Loading branch information
singularitti authored Sep 30, 2019
2 parents 44de5db + 27cc892 commit 5f3a054
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 123 deletions.
117 changes: 14 additions & 103 deletions src/Collections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ julia>
module Collections

using InteractiveUtils
using Unitful: AbstractQuantity, @u_str, uconvert, NoUnits, 𝐋, 𝐌, 𝐓, Dimension, Dimensions
using Unitful: AbstractQuantity, @u_str, Dimension, Dimensions
import Unitful
using UnitfulAstro

Expand Down Expand Up @@ -319,7 +319,8 @@ function apply(::EnergyForm, eos::BirchMurnaghan4th, v)

f = (cbrt(v0 / v)^2 - 1) / 2
h = b0 * bpp0 + bp0^2
return e0 + 3 / 8 * v0 * b0 * f^2 * ((9h - 63bp0 + 143) * f^2 + 12 * (bp0 - 4) * f + 12)
return e0 +
3 / 8 * v0 * b0 * f^2 * ((9h - 63bp0 + 143) * f^2 + 12 * (bp0 - 4) * f + 12)
end
"""
apply(EnergyForm(), eos::PoirierTarantola2nd, v)
Expand Down Expand Up @@ -664,107 +665,17 @@ fieldvalues(eos::EquationOfState) = [getfield(eos, i) for i in 1:nfields(eos)]

Base.eltype(T::Type{<:EquationOfState}) = promote_type(T.types...)

Unitful.promote_unit(::S, ::T) where {S<:Unitful.EnergyUnits,T<:Unitful.EnergyUnits} = u"eV"
Unitful.promote_unit(::S, ::T) where {S<:Unitful.LengthUnits,T<:Unitful.LengthUnits} = u"angstrom"
Unitful.promote_unit(::S, ::T) where {S<:Unitful.PressureUnits,T<:Unitful.PressureUnits} = u"eV/angstrom^3"

Unitful.upreferred(::Dimensions{(Dimension{:Length}(2//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV"
Unitful.upreferred(::Dimensions{(Dimension{:Length}(3//1),)}) = u"angstrom^3"
Unitful.upreferred(::Dimensions{(Dimension{:Length}(-1//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV/angstrom^3"
Unitful.upreferred(eos::Murnaghan{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
Murnaghan(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::BirchMurnaghan2nd{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
BirchMurnaghan2nd(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::BirchMurnaghan3rd{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
BirchMurnaghan3rd(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::BirchMurnaghan4th{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
BirchMurnaghan4th(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"angstrom^3/eV", eos.bpp0) * uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::PoirierTarantola2nd{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
PoirierTarantola2nd(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::PoirierTarantola3rd{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
PoirierTarantola3rd(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::PoirierTarantola4th{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
PoirierTarantola4th(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"angstrom^3/eV", eos.bpp0) * uconvert(u"eV", eos.e0),
)
Unitful.upreferred(eos::Vinet{
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
<:AbstractQuantity,
}) =
Vinet(
uconvert(u"angstrom^3", eos.v0),
uconvert(u"eV/angstrom^3", eos.b0),
uconvert(NoUnits, eos.bp0),
uconvert(u"eV", eos.e0),
)
Unitful.upreferred(::Dimensions{(
Dimension{:Length}(2 // 1),
Dimension{:Mass}(1 // 1),
Dimension{:Time}(-2 // 1),
)}) = u"eV"
Unitful.upreferred(::Dimensions{(Dimension{:Length}(3 // 1),)}) = u"angstrom^3"
Unitful.upreferred(::Dimensions{(
Dimension{:Length}(-1 // 1),
Dimension{:Mass}(1 // 1),
Dimension{:Time}(-2 // 1),
)}) = u"eV/angstrom^3"
# =============================== Miscellaneous ============================== #

end
6 changes: 5 additions & 1 deletion src/Find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ function findvolume(
eos::EquationOfState,
y::Real,
domain::Union{AbstractVector,Tuple},
method::Union{AbstractNonBracketing,AbstractHalleyLikeMethod,AbstractNewtonLikeMethod},
method::Union{
AbstractNonBracketing,
AbstractHalleyLikeMethod,
AbstractNewtonLikeMethod,
},
)
f(v) = apply(form, eos, v) - y
return find_zero(f, median(domain), method)
Expand Down
46 changes: 27 additions & 19 deletions src/NonlinearFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ julia>
module NonlinearFitting

using LsqFit: curve_fit
using Unitful
import Unitful: AbstractQuantity, 𝐋, 𝐌, 𝐓
using Unitful: AbstractQuantity, upreferred, ustrip, unit, uconvert

import ..EquationForm
using ..Collections

export lsqfit

abstract type UnitTrait end
struct NoUnit <: UnitTrait end
struct HasUnit <: UnitTrait end
# This idea is borrowed from [SimpleTraits.jl](https://github.com/mauro3/SimpleTraits.jl/blob/master/src/SimpleTraits.jl).
abstract type Trait end
abstract type Not{T<:Trait} <: Trait end
struct HasUnit <: Trait end

_traitfn(T::Type{<:Number}) = NoUnit
_traitfn(T::Type{<:AbstractQuantity}) = HasUnit
_unit_trait(T::Type{<:Real}) = Not{HasUnit}
_unit_trait(T::Type{<:AbstractQuantity}) = HasUnit

"""
lsqfit(form, eos, xdata, ydata; debug = false, kwargs...)
Expand All @@ -48,41 +48,49 @@ function lsqfit(
kwargs...,
)
T = eltype(eos)
return lsqfit(form, eos, xdata, ydata, _traitfn(T), kwargs...)
return lsqfit(_unit_trait(T), form, eos, xdata, ydata, kwargs...)
end # function lsqfit
function lsqfit(
::Type{Not{HasUnit}},
form::EquationForm,
eos::EquationOfState,
xdata::AbstractVector,
ydata::AbstractVector,
trait::Type{NoUnit};
ydata::AbstractVector;
debug = false,
kwargs...,
)
T = promote_type(eltype(eos), eltype(xdata), eltype(ydata), Float64)
E = typeof(eos).name.wrapper
model(x, p) = map(apply(form, E(p...)), x)
fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)); kwargs...)
model = (x, p) -> map(apply(form, E(p...)), x)
fitted = curve_fit(
model,
T.(xdata),
T.(ydata),
T.(Collections.fieldvalues(eos)),
kwargs...,
)
return debug ? fitted : E(fitted.param...)
end # function lsqfit
function lsqfit(
::Type{HasUnit},
form::EquationForm,
eos::EquationOfState,
xdata::AbstractVector,
ydata::AbstractVector,
trait::Type{HasUnit};
ydata::AbstractVector;
kwargs...,
)
E = typeof(eos).name.wrapper
units = unit.(Collections.fieldvalues(eos))
trial_params = map(ustrip, Collections.fieldvalues(upreferred(eos)))
xdata, ydata = map(ustrip upreferred, xdata), map(ustrip upreferred, ydata)
values = Collections.fieldvalues(eos)
original_units = map(unit, values)
trial_params, xdata, ydata = [map(ustrip upreferred, x) for x in (values, xdata, ydata)]
result = lsqfit(form, E(trial_params...), xdata, ydata, kwargs...)
if result isa EquationOfState
E(
[uconvert(u, Collections.fieldvalues(result)[i] * upreferred(u)) for (i, u) in enumerate(units)]...
data = Collections.fieldvalues(result)
return E(
[uconvert(u, data[i] * upreferred(u)) for (i, u) in enumerate(original_units)]...
)
end
return result
end # function lsqfit

end

0 comments on commit 5f3a054

Please sign in to comment.