Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flesh out WeaveFeaturization #128

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions src/ChemistryFeaturization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ export output_shape, get_value

include("featurizations/featurizations.jl")
export Featurization
using .Featurization: GraphNodeFeaturization, encode
export GraphNodeFeaturization, encode
using .Featurization: GraphNodeFeaturization, WeaveFeaturization, encode
export GraphNodeFeaturization, WeaveFeaturization, encode

export encodable_elements, decode

Expand Down
1 change: 1 addition & 0 deletions src/features/speciesfeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function SpeciesFeatureDescriptor(name::String)
else
# TODO: figure out default binning situation for continuous-valued SFD's
#codec = OneHotOneCold(false, )
codec = DirectCodec(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this actually make sense as a default?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt it, but I needed something there for the time being. Ideally we would have have to get a number of bins from the user or guess. I don't think OneHotOneCold would work since that implies categorical variable, which continuous variables are not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OneHotOneCold works for categorical variables, you just have to specify bins. That's what we've been doing all along since the DirectCodec was just added...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. DataFrames does some kind of trickery to inspect whether a column is categorical or notz but I don't think we need to be doing anything too sophisticated at this stage. I was thinking assuming something is a continuous variable as the default would be reasonable and if a feature is categorical that would need to be described explicitly.

end
SpeciesFeatureDescriptor{info[:A],typeof(codec)}(
name,
Expand Down
115 changes: 114 additions & 1 deletion src/featurizations/weavefeaturization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,125 @@
using ..ChemistryFeaturization.AbstractType: AbstractFeaturization
using ..ChemistryFeaturization.FeatureDescriptor:
AbstractAtomFeatureDescriptor, AbstractPairFeatureDescriptor
using ..ChemistryFeaturization.Codec: DirectCodec
using ..ChemistryFeaturization.Utils

struct WeaveFeaturization <: AbstractFeaturization
element_features::Vector{<:AbstractAtomFeatureDescriptor}
atom_features::Vector{<:AbstractAtomFeatureDescriptor}
bond_features::Vector{<:BondFeatureDescriptor}
pair_features::Vector{<:AbstractPairFeatureDescriptor}
end

function WeaveFeaturization(element_feature_list = ["Atomic no"],
species_feature_list = ["degree",
"implicithconnected",
"charge",
"radical_electrons",
"hybridization",
"isaromatic",
"hydrogenconnected"],
bond_feature_list = ["bondorder", "isaromaticbond", "isringbond"])
elements = ElementFeatureDescriptor.(element_feature_list)
species = SpeciesFeatureDescriptor.(species_feature_list)
bonds = BondFeatureDescriptor.(bond_feature_list)
# pairs = PairFeatureDescriptor.(default_atom_feature_list)
WeaveFeaturization(elements, species, bonds, bonds)
Comment on lines +16 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function WeaveFeaturization(element_feature_list = ["Atomic no"],
species_feature_list = ["degree",
"implicithconnected",
"charge",
"radical_electrons",
"hybridization",
"isaromatic",
"hydrogenconnected"],
bond_feature_list = ["bondorder", "isaromaticbond", "isringbond"])
elements = ElementFeatureDescriptor.(element_feature_list)
species = SpeciesFeatureDescriptor.(species_feature_list)
bonds = BondFeatureDescriptor.(bond_feature_list)
# pairs = PairFeatureDescriptor.(default_atom_feature_list)
WeaveFeaturization(elements, species, bonds, bonds)
function WeaveFeaturization(
element_feature_list = ["Atomic no"],
species_feature_list = [
"degree",
"implicithconnected",
"charge",
"radical_electrons",
"hybridization",
"isaromatic",
"hydrogenconnected",
],
bond_feature_list = ["bondorder", "isaromaticbond", "isringbond"],
)
elements = ElementFeatureDescriptor.(element_feature_list)
species = SpeciesFeatureDescriptor.(species_feature_list)
bonds = BondFeatureDescriptor.(bond_feature_list)
# pairs = PairFeatureDescriptor.(default_atom_feature_list)
WeaveFeaturization(elements, species, bonds, bonds)

end

WeaveFeaturization(; kw...) = WeaveFeaturization(values(kw)...)

function encodable_elements(fzn::WeaveFeaturization)
# TODO: implement me!
intersect([encodable_elements(f) for f in fzn.atom_features]...),
intersect([encodable_elements(f) for f in fzn.bond_features]...),
intersect([encodable_elements(f) for f in fzn.pair_features]...)
end

function atom_features(feat, mol; kw...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

end

function bond_feature(bond, mol; kw...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

end

function pair_features(feat, mol; kw...)

end

const DEEPCHEM_ATOM_SYMBOLS = [
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"As",
"Al",
"I",
"B",
"V",
"K",
"Tl",
"Yb",
"Sb",
"Sn",
"Ag",
"Pd",
"Co",
"Se",
"Ti",
"Zn",
"H", # H?
"Li",
"Ge",
"Cu",
"Au",
"Ni",
"Cd",
"In",
"Mn",
"Zr",
"Cr",
"Pt",
"Hg",
"Pb",
"Unknown"
]
Comment on lines +53 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"As",
"Al",
"I",
"B",
"V",
"K",
"Tl",
"Yb",
"Sb",
"Sn",
"Ag",
"Pd",
"Co",
"Se",
"Ti",
"Zn",
"H", # H?
"Li",
"Ge",
"Cu",
"Au",
"Ni",
"Cd",
"In",
"Mn",
"Zr",
"Cr",
"Pt",
"Hg",
"Pb",
"Unknown"
]
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"As",
"Al",
"I",
"B",
"V",
"K",
"Tl",
"Yb",
"Sb",
"Sn",
"Ag",
"Pd",
"Co",
"Se",
"Ti",
"Zn",
"H", # H?
"Li",
"Ge",
"Cu",
"Au",
"Ni",
"Cd",
"In",
"Mn",
"Zr",
"Cr",
"Pt",
"Hg",
"Pb",
"Unknown",
]


# default_atom_feature_list = ["symbol","degree","implicit_valence","formal_charge","radical_electrons","hybridization","aromaticity","total_H_num" ]
# default_bond_feature_list = ["bond_type","isConjugated","isInring"]

struct FeaturizedWeave
atom_features
bond_features
pair_features
Comment on lines +103 to +105
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
atom_features
bond_features
pair_features
atom_features::Any
bond_features::Any
pair_features::Any

end

function encode(fzn::WeaveFeaturization, ag::AtomGraph; atom_feature_kwargs = (;),
bond_feature_kwargs = (;),
pair_feature_kwargs = (;))
sf = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.atom_features)
ef = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.element_features)
atom_and_elements = vcat(sf, ef)
bf = cat(map(x -> encode(x, ag, bond_feature_kwargs...), fzn.bond_features)..., dims = 3)
pf = cat(map(x -> encode(x, ag, pair_feature_kwargs...), fzn.pair_features)..., dims = 3)
# Return FeaturizedAtoms here
atom_and_elements, vcat(bf, pf)
# FeaturizedWeave(atom_and_elements, bf, pf)
Comment on lines +108 to +118
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function encode(fzn::WeaveFeaturization, ag::AtomGraph; atom_feature_kwargs = (;),
bond_feature_kwargs = (;),
pair_feature_kwargs = (;))
sf = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.atom_features)
ef = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.element_features)
atom_and_elements = vcat(sf, ef)
bf = cat(map(x -> encode(x, ag, bond_feature_kwargs...), fzn.bond_features)..., dims = 3)
pf = cat(map(x -> encode(x, ag, pair_feature_kwargs...), fzn.pair_features)..., dims = 3)
# Return FeaturizedAtoms here
atom_and_elements, vcat(bf, pf)
# FeaturizedWeave(atom_and_elements, bf, pf)
function encode(
fzn::WeaveFeaturization,
ag::AtomGraph;
atom_feature_kwargs = (;),
bond_feature_kwargs = (;),
pair_feature_kwargs = (;),
)
sf = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.atom_features)
ef = mapreduce(x -> encode(x, ag, atom_feature_kwargs...), vcat, fzn.element_features)
atom_and_elements = vcat(sf, ef)
bf =
cat(map(x -> encode(x, ag, bond_feature_kwargs...), fzn.bond_features)..., dims = 3)
pf =
cat(map(x -> encode(x, ag, pair_feature_kwargs...), fzn.pair_features)..., dims = 3)
# Return FeaturizedAtoms here
atom_and_elements, vcat(bf, pf)
# FeaturizedWeave(atom_and_elements, bf, pf)

end

function Base.show(io::IO, fzn::WeaveFeaturization)
println(io, "WeaveFeaturization(")
println(io, " Species Features: $(map(x -> x.name, fzn.element_features))")
println(io, " Atom Features: $(map(x -> x.name, fzn.atom_features))")
println(io, " Bond Features: $(map(x -> x.name, fzn.bond_features))")
println(io, ")")
Comment on lines +122 to +126
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
println(io, "WeaveFeaturization(")
println(io, " Species Features: $(map(x -> x.name, fzn.element_features))")
println(io, " Atom Features: $(map(x -> x.name, fzn.atom_features))")
println(io, " Bond Features: $(map(x -> x.name, fzn.bond_features))")
println(io, ")")
println(io, "WeaveFeaturization(")
println(io, " Species Features: $(map(x -> x.name, fzn.element_features))")
println(io, " Atom Features: $(map(x -> x.name, fzn.atom_features))")
println(io, " Bond Features: $(map(x -> x.name, fzn.bond_features))")
println(io, ")")

end
37 changes: 37 additions & 0 deletions src/utils/speciesfeature_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module SpeciesFeatureUtils

using ...ChemistryFeaturization.Codec: DirectCodec

using MolecularGraph

# some convenience SFD constructors mapping names of species features to MolecularGraph functions...
Expand Down Expand Up @@ -50,6 +52,41 @@ const sfd_names_props = Dict(
:encodable_elements => mg_elements,
:possible_vals => [0, 1, 2], # also not certain this is correct
),
"degree" => Dict(
:A => GraphMol,
:compute_f => nodedegree,
:categorical => true,
:encodable_elements => mg_elements,
:possible_vals => collect(0:10),
),
"radical_electrons" => Dict(
:A => GraphMol,
:compute_f => multiplicity,
:categorical => true,
:encodable_elements => mg_elements,
:possible_vals => [1, 2, 3],
),
"multiplicity" => Dict(
:A => GraphMol,
:compute_f => multiplicity,
:categorical => true,
:encodable_elements => mg_elements,
:possible_vals => [1, 2, 3],
),
"implicithconnected" => Dict(
:A => GraphMol,
:compute_f => implicithconnected,
:categorical => false,
:encodable_elements => mg_elements,
:codec => DirectCodec,
),
"charge" => Dict(
:A => GraphMol,
:compute_f => charge,
:categorical => false,
:encodable_elements => mg_elements,
:codec => DirectCodec,
),
)

end