From e3dfd4a9bcf88b3f097685e9b6d1a8eb5d6d5b34 Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 30 Sep 2019 02:36:16 -0400 Subject: [PATCH 1/5] Implement trait `HasUnit` following SimpleTraits.jl --- src/NonlinearFitting.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/NonlinearFitting.jl b/src/NonlinearFitting.jl index ca390dc..9c1f712 100644 --- a/src/NonlinearFitting.jl +++ b/src/NonlinearFitting.jl @@ -20,12 +20,13 @@ using ..Collections export lsqfit -abstract type UnitTrait end -struct NoUnit <: UnitTrait end -struct HasUnit <: UnitTrait end +# This idea is borrowed from [SimpleTraits.jl](https://github.com/mauro3/SimpleTraits.jl/blob/master/src/SimpleTraits.jl). +abstract type Trait end +abstract type Not{T<:Trait} <: Trait end +struct HasUnit <: Trait end -_traitfn(T::Type{<:Number}) = NoUnit -_traitfn(T::Type{<:AbstractQuantity}) = HasUnit +_unit_trait(T::Type{<:Real}) = Not{HasUnit} +_unit_trait(T::Type{<:AbstractQuantity}) = HasUnit """ lsqfit(form, eos, xdata, ydata; debug = false, kwargs...) @@ -48,29 +49,29 @@ function lsqfit( kwargs..., ) T = eltype(eos) - return lsqfit(form, eos, xdata, ydata, _traitfn(T), kwargs...) + return lsqfit(_unit_trait(T), form, eos, xdata, ydata, kwargs...) end # function lsqfit function lsqfit( + ::Type{Not{HasUnit}}, form::EquationForm, eos::EquationOfState, xdata::AbstractVector, - ydata::AbstractVector, - trait::Type{NoUnit}; + ydata::AbstractVector; debug = false, kwargs..., ) T = promote_type(eltype(eos), eltype(xdata), eltype(ydata), Float64) E = typeof(eos).name.wrapper model(x, p) = map(apply(form, E(p...)), x) - fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)); kwargs...) + fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)), kwargs...) return debug ? fitted : E(fitted.param...) end # function lsqfit function lsqfit( + ::Type{HasUnit}, form::EquationForm, eos::EquationOfState, xdata::AbstractVector, - ydata::AbstractVector, - trait::Type{HasUnit}; + ydata::AbstractVector; kwargs..., ) E = typeof(eos).name.wrapper From fcc63d9c502e358c77d64caaf1ebc0a4ba73dcad Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 30 Sep 2019 03:02:18 -0400 Subject: [PATCH 2/5] refactor: Simplify `lsqfit` --- src/NonlinearFitting.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/NonlinearFitting.jl b/src/NonlinearFitting.jl index 9c1f712..d54182c 100644 --- a/src/NonlinearFitting.jl +++ b/src/NonlinearFitting.jl @@ -75,15 +75,17 @@ function lsqfit( kwargs..., ) E = typeof(eos).name.wrapper - units = unit.(Collections.fieldvalues(eos)) - trial_params = map(ustrip, Collections.fieldvalues(upreferred(eos))) - xdata, ydata = map(ustrip ∘ upreferred, xdata), map(ustrip ∘ upreferred, ydata) + values = Collections.fieldvalues(eos) + original_units = map(unit, values) + trial_params, xdata, ydata = [map(ustrip ∘ upreferred, x) for x in (values, xdata, ydata)] result = lsqfit(form, E(trial_params...), xdata, ydata, kwargs...) if result isa EquationOfState - E( - [uconvert(u, Collections.fieldvalues(result)[i] * upreferred(u)) for (i, u) in enumerate(units)]... + data = Collections.fieldvalues(result) + return E( + [uconvert(u, data[i] * upreferred(u)) for (i, u) in enumerate(original_units)]... ) end + return result end # function lsqfit end From 0e70bf01fc3653bfdd13bc87162396b74cbfd421 Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 30 Sep 2019 03:03:29 -0400 Subject: [PATCH 3/5] Remove unneeded `promote_unit` & `upreferred` --- src/Collections.jl | 98 ---------------------------------------------- 1 file changed, 98 deletions(-) diff --git a/src/Collections.jl b/src/Collections.jl index e76bbb2..325c71f 100644 --- a/src/Collections.jl +++ b/src/Collections.jl @@ -664,107 +664,9 @@ fieldvalues(eos::EquationOfState) = [getfield(eos, i) for i in 1:nfields(eos)] Base.eltype(T::Type{<:EquationOfState}) = promote_type(T.types...) -Unitful.promote_unit(::S, ::T) where {S<:Unitful.EnergyUnits,T<:Unitful.EnergyUnits} = u"eV" -Unitful.promote_unit(::S, ::T) where {S<:Unitful.LengthUnits,T<:Unitful.LengthUnits} = u"angstrom" -Unitful.promote_unit(::S, ::T) where {S<:Unitful.PressureUnits,T<:Unitful.PressureUnits} = u"eV/angstrom^3" - Unitful.upreferred(::Dimensions{(Dimension{:Length}(2//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV" Unitful.upreferred(::Dimensions{(Dimension{:Length}(3//1),)}) = u"angstrom^3" Unitful.upreferred(::Dimensions{(Dimension{:Length}(-1//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV/angstrom^3" -Unitful.upreferred(eos::Murnaghan{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - Murnaghan( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::BirchMurnaghan2nd{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - BirchMurnaghan2nd( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::BirchMurnaghan3rd{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - BirchMurnaghan3rd( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::BirchMurnaghan4th{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - BirchMurnaghan4th( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"angstrom^3/eV", eos.bpp0) * uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::PoirierTarantola2nd{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - PoirierTarantola2nd( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::PoirierTarantola3rd{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - PoirierTarantola3rd( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::PoirierTarantola4th{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - PoirierTarantola4th( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"angstrom^3/eV", eos.bpp0) * uconvert(u"eV", eos.e0), - ) -Unitful.upreferred(eos::Vinet{ - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, - <:AbstractQuantity, -}) = - Vinet( - uconvert(u"angstrom^3", eos.v0), - uconvert(u"eV/angstrom^3", eos.b0), - uconvert(NoUnits, eos.bp0), - uconvert(u"eV", eos.e0), - ) # =============================== Miscellaneous ============================== # end From 6a93cd8e2d0b0e241227f4dea4cd1b1f7c13ab7a Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 30 Sep 2019 03:08:58 -0400 Subject: [PATCH 4/5] Fix imports --- src/Collections.jl | 2 +- src/NonlinearFitting.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Collections.jl b/src/Collections.jl index 325c71f..77dc7ea 100644 --- a/src/Collections.jl +++ b/src/Collections.jl @@ -12,7 +12,7 @@ julia> module Collections using InteractiveUtils -using Unitful: AbstractQuantity, @u_str, uconvert, NoUnits, 𝐋, 𝐌, 𝐓, Dimension, Dimensions +using Unitful: AbstractQuantity, @u_str, Dimension, Dimensions import Unitful using UnitfulAstro diff --git a/src/NonlinearFitting.jl b/src/NonlinearFitting.jl index d54182c..cf2eabe 100644 --- a/src/NonlinearFitting.jl +++ b/src/NonlinearFitting.jl @@ -12,8 +12,7 @@ julia> module NonlinearFitting using LsqFit: curve_fit -using Unitful -import Unitful: AbstractQuantity, 𝐋, 𝐌, 𝐓 +using Unitful: AbstractQuantity, upreferred, ustrip, unit, uconvert import ..EquationForm using ..Collections From 27cc8925f80018bd62d7a546d5f65790886eb81f Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 30 Sep 2019 03:14:02 -0400 Subject: [PATCH 5/5] style: Restyle code --- src/Collections.jl | 17 +++++++++++++---- src/Find.jl | 6 +++++- src/NonlinearFitting.jl | 10 ++++++++-- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/Collections.jl b/src/Collections.jl index 77dc7ea..a80acbb 100644 --- a/src/Collections.jl +++ b/src/Collections.jl @@ -319,7 +319,8 @@ function apply(::EnergyForm, eos::BirchMurnaghan4th, v) f = (cbrt(v0 / v)^2 - 1) / 2 h = b0 * bpp0 + bp0^2 - return e0 + 3 / 8 * v0 * b0 * f^2 * ((9h - 63bp0 + 143) * f^2 + 12 * (bp0 - 4) * f + 12) + return e0 + + 3 / 8 * v0 * b0 * f^2 * ((9h - 63bp0 + 143) * f^2 + 12 * (bp0 - 4) * f + 12) end """ apply(EnergyForm(), eos::PoirierTarantola2nd, v) @@ -664,9 +665,17 @@ fieldvalues(eos::EquationOfState) = [getfield(eos, i) for i in 1:nfields(eos)] Base.eltype(T::Type{<:EquationOfState}) = promote_type(T.types...) -Unitful.upreferred(::Dimensions{(Dimension{:Length}(2//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV" -Unitful.upreferred(::Dimensions{(Dimension{:Length}(3//1),)}) = u"angstrom^3" -Unitful.upreferred(::Dimensions{(Dimension{:Length}(-1//1),Dimension{:Mass}(1//1),Dimension{:Time}(-2//1))}) = u"eV/angstrom^3" +Unitful.upreferred(::Dimensions{( + Dimension{:Length}(2 // 1), + Dimension{:Mass}(1 // 1), + Dimension{:Time}(-2 // 1), +)}) = u"eV" +Unitful.upreferred(::Dimensions{(Dimension{:Length}(3 // 1),)}) = u"angstrom^3" +Unitful.upreferred(::Dimensions{( + Dimension{:Length}(-1 // 1), + Dimension{:Mass}(1 // 1), + Dimension{:Time}(-2 // 1), +)}) = u"eV/angstrom^3" # =============================== Miscellaneous ============================== # end diff --git a/src/Find.jl b/src/Find.jl index dfd2317..2e6993f 100644 --- a/src/Find.jl +++ b/src/Find.jl @@ -46,7 +46,11 @@ function findvolume( eos::EquationOfState, y::Real, domain::Union{AbstractVector,Tuple}, - method::Union{AbstractNonBracketing,AbstractHalleyLikeMethod,AbstractNewtonLikeMethod}, + method::Union{ + AbstractNonBracketing, + AbstractHalleyLikeMethod, + AbstractNewtonLikeMethod, + }, ) f(v) = apply(form, eos, v) - y return find_zero(f, median(domain), method) diff --git a/src/NonlinearFitting.jl b/src/NonlinearFitting.jl index cf2eabe..fa4df01 100644 --- a/src/NonlinearFitting.jl +++ b/src/NonlinearFitting.jl @@ -61,8 +61,14 @@ function lsqfit( ) T = promote_type(eltype(eos), eltype(xdata), eltype(ydata), Float64) E = typeof(eos).name.wrapper - model(x, p) = map(apply(form, E(p...)), x) - fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)), kwargs...) + model = (x, p) -> map(apply(form, E(p...)), x) + fitted = curve_fit( + model, + T.(xdata), + T.(ydata), + T.(Collections.fieldvalues(eos)), + kwargs..., + ) return debug ? fitted : E(fitted.param...) end # function lsqfit function lsqfit(