diff --git a/src/ChemistryFeaturization.jl b/src/ChemistryFeaturization.jl
index ada36f69..ccbcd335 100644
--- a/src/ChemistryFeaturization.jl
+++ b/src/ChemistryFeaturization.jl
@@ -34,8 +34,8 @@ export output_shape, get_value
 
 include("featurizations/featurizations.jl")
 export Featurization
-using .Featurization: GraphNodeFeaturization, encode
-export GraphNodeFeaturization, encode
+using .Featurization: GraphNodeFeaturization, WeaveFeaturization, encode
+export GraphNodeFeaturization, WeaveFeaturization, encode
 
 export encodable_elements, decode
 
diff --git a/src/features/speciesfeature.jl b/src/features/speciesfeature.jl
index fa88872f..7816eb24 100644
--- a/src/features/speciesfeature.jl
+++ b/src/features/speciesfeature.jl
@@ -48,6 +48,7 @@ function SpeciesFeatureDescriptor(name::String)
     else
         # TODO: figure out default binning situation for continuous-valued SFD's
         #codec = OneHotOneCold(false, )
+        codec = DirectCodec(1) # stopgap until the binning TODO above is settled
     end
     SpeciesFeatureDescriptor{info[:A],typeof(codec)}(
         name,
diff --git a/src/featurizations/weavefeaturization.jl b/src/featurizations/weavefeaturization.jl
index 8fe418b3..add48c9d 100644
--- a/src/featurizations/weavefeaturization.jl
+++ b/src/featurizations/weavefeaturization.jl
@@ -3,12 +3,191 @@ using ..ChemistryFeaturization.AbstractType: AbstractFeaturization
 
 using ..ChemistryFeaturization.FeatureDescriptor:
     AbstractAtomFeatureDescriptor, AbstractPairFeatureDescriptor
+using ..ChemistryFeaturization.Codec: DirectCodec
+using ..ChemistryFeaturization.Utils
 
+"""
+    WeaveFeaturization
+
+Featurization holding four groups of feature descriptors: element-level, atom-level,
+bond-level and pair-level.
+"""
 struct WeaveFeaturization <: AbstractFeaturization
+    element_features::Vector{<:AbstractAtomFeatureDescriptor}
     atom_features::Vector{<:AbstractAtomFeatureDescriptor}
+    bond_features::Vector{<:BondFeatureDescriptor}
     pair_features::Vector{<:AbstractPairFeatureDescriptor}
 end
 
+# Default feature names used by the convenience constructors below.
+const DEFAULT_WEAVE_ELEMENT_FEATURES = ["Atomic no"]
+const DEFAULT_WEAVE_SPECIES_FEATURES = [
+    "degree",
+    "implicithconnected",
+    "charge",
+    "radical_electrons",
+    "hybridization",
+    "isaromatic",
+    "hydrogenconnected",
+]
+const DEFAULT_WEAVE_BOND_FEATURES = ["bondorder", "isaromaticbond", "isringbond"]
+
+"""
+    WeaveFeaturization(element_feature_list, [species_feature_list, [bond_feature_list]])
+    WeaveFeaturization(; element_feature_list, species_feature_list, bond_feature_list)
+
+Build a `WeaveFeaturization` from lists of feature names.
+
+Each name is passed to `ElementFeatureDescriptor`, `SpeciesFeatureDescriptor` or
+`BondFeatureDescriptor` respectively. The bond descriptors are also stored as the pair
+features. The keyword form matches arguments by name, so any subset may be given.
+"""
+function WeaveFeaturization(
+    element_feature_list,
+    species_feature_list = DEFAULT_WEAVE_SPECIES_FEATURES,
+    bond_feature_list = DEFAULT_WEAVE_BOND_FEATURES,
+)
+    elements = ElementFeatureDescriptor.(element_feature_list)
+    species = SpeciesFeatureDescriptor.(species_feature_list)
+    bonds = BondFeatureDescriptor.(bond_feature_list)
+    # No dedicated pair descriptors are built yet, so `bonds` fills both slots.
+    return WeaveFeaturization(elements, species, bonds, bonds)
+end
+
+# Keyword form. It owns the zero-argument call; the positional method above requires
+# at least one argument, so the two never overwrite each other or recurse.
+function WeaveFeaturization(;
+    element_feature_list = DEFAULT_WEAVE_ELEMENT_FEATURES,
+    species_feature_list = DEFAULT_WEAVE_SPECIES_FEATURES,
+    bond_feature_list = DEFAULT_WEAVE_BOND_FEATURES,
+)
+    return WeaveFeaturization(element_feature_list, species_feature_list, bond_feature_list)
+end
+
+# NOTE(review): returns a tuple with one intersection per descriptor group and does not
+# consult `element_features` -- confirm this is the intended shape.
 function encodable_elements(fzn::WeaveFeaturization)
-    # TODO: implement me!
+    return (
+        intersect([encodable_elements(f) for f in fzn.atom_features]...),
+        intersect([encodable_elements(f) for f in fzn.bond_features]...),
+        intersect([encodable_elements(f) for f in fzn.pair_features]...),
+    )
+end
+
+# Placeholder: empty body, returns `nothing`.
+function atom_features(feat, mol; kw...)
+
+end
+
+# Placeholder: empty body, returns `nothing`.
+function bond_feature(bond, mol; kw...)
+
+end
+
+# Placeholder: empty body, returns `nothing`.
+function pair_features(feat, mol; kw...)
+
+end
+
+# Element symbols plus an "Unknown" entry; not referenced by the code visible here.
+# NOTE(review): presumably mirrors DeepChem's atom symbol list -- confirm.
+const DEEPCHEM_ATOM_SYMBOLS = [
+    "C",
+    "N",
+    "O",
+    "S",
+    "F",
+    "Si",
+    "P",
+    "Cl",
+    "Br",
+    "Mg",
+    "Na",
+    "Ca",
+    "Fe",
+    "As",
+    "Al",
+    "I",
+    "B",
+    "V",
+    "K",
+    "Tl",
+    "Yb",
+    "Sb",
+    "Sn",
+    "Ag",
+    "Pd",
+    "Co",
+    "Se",
+    "Ti",
+    "Zn",
+    "H", # H?
+    "Li",
+    "Ge",
+    "Cu",
+    "Au",
+    "Ni",
+    "Cd",
+    "In",
+    "Mn",
+    "Zr",
+    "Cr",
+    "Pt",
+    "Hg",
+    "Pb",
+    "Unknown"
+]
+
+# default_atom_feature_list = ["symbol","degree","implicit_valence","formal_charge","radical_electrons","hybridization","aromaticity","total_H_num" ]
+# default_bond_feature_list = ["bond_type","isConjugated","isInring"]
+
+# Container for the three encoded feature groups; `encode` below does not build it yet.
+struct FeaturizedWeave
+    atom_features
+    bond_features
+    pair_features
+end
+
+"""
+    encode(fzn::WeaveFeaturization, ag::AtomGraph; atom_feature_kwargs = (;),
+           bond_feature_kwargs = (;), pair_feature_kwargs = (;))
+
+Encode `ag` with every descriptor in `fzn` and return a 2-tuple.
+
+The first entry is the `vcat` of the atom-feature encodings followed by the
+element-feature encodings. The second is the `vcat` of the bond and pair encodings, each
+concatenated along dimension 3. Each `*_kwargs` collection is forwarded as keyword
+arguments to the per-descriptor `encode`; `atom_feature_kwargs` is used for both the atom
+and the element descriptors.
+"""
+function encode(
+    fzn::WeaveFeaturization,
+    ag::AtomGraph;
+    atom_feature_kwargs = (;),
+    bond_feature_kwargs = (;),
+    pair_feature_kwargs = (;),
+)
+    # Splat after `;` so the options arrive as keywords, not as extra positionals.
+    sf = mapreduce(x -> encode(x, ag; atom_feature_kwargs...), vcat, fzn.atom_features)
+    ef = mapreduce(x -> encode(x, ag; atom_feature_kwargs...), vcat, fzn.element_features)
+    atom_and_elements = vcat(sf, ef)
+    bf = cat(
+        map(x -> encode(x, ag; bond_feature_kwargs...), fzn.bond_features)...;
+        dims = 3,
+    )
+    pf = cat(
+        map(x -> encode(x, ag; pair_feature_kwargs...), fzn.pair_features)...;
+        dims = 3,
+    )
+    # TODO: return `FeaturizedWeave(atom_and_elements, bf, pf)`; confirm `vcat` is intended.
+    return atom_and_elements, vcat(bf, pf)
+end
+
+# Multi-line summary listing the names of the element, atom and bond descriptors.
+function Base.show(io::IO, fzn::WeaveFeaturization)
+    println(io, "WeaveFeaturization(")
+    println(io, " Element Features: $(map(x -> x.name, fzn.element_features))")
+    println(io, " Atom Features: $(map(x -> x.name, fzn.atom_features))")
+    println(io, " Bond Features: $(map(x -> x.name, fzn.bond_features))")
+    println(io, ")")
 end
diff --git a/src/utils/speciesfeature_utils.jl b/src/utils/speciesfeature_utils.jl
index 49c35fb9..4a54ae3f 100644
--- a/src/utils/speciesfeature_utils.jl
+++ b/src/utils/speciesfeature_utils.jl
@@ -1,5 +1,7 @@
 module SpeciesFeatureUtils
 
+using ...ChemistryFeaturization.Codec: DirectCodec
+
 using MolecularGraph
 
 # some convenience SFD constructors mapping names of species features to MolecularGraph functions...
@@ -50,6 +52,41 @@ const sfd_names_props = Dict(
         :encodable_elements => mg_elements,
         :possible_vals => [0, 1, 2], # also not certain this is correct
     ),
+    "degree" => Dict( # categorical feature computed with `nodedegree`
+        :A => GraphMol,
+        :compute_f => nodedegree,
+        :categorical => true,
+        :encodable_elements => mg_elements,
+        :possible_vals => collect(0:10), # allowed values 0 through 10
+    ),
+    "radical_electrons" => Dict( # NOTE(review): same settings as "multiplicity" below
+        :A => GraphMol,
+        :compute_f => multiplicity, # NOTE(review): confirm this is the intended function
+        :categorical => true,
+        :encodable_elements => mg_elements,
+        :possible_vals => [1, 2, 3],
+    ),
+    "multiplicity" => Dict( # categorical feature computed with `multiplicity`
+        :A => GraphMol,
+        :compute_f => multiplicity,
+        :categorical => true,
+        :encodable_elements => mg_elements,
+        :possible_vals => [1, 2, 3],
+    ),
+    "implicithconnected" => Dict( # non-categorical: no :possible_vals, names a codec
+        :A => GraphMol,
+        :compute_f => implicithconnected,
+        :categorical => false,
+        :encodable_elements => mg_elements,
+        :codec => DirectCodec, # the type itself, not an instance -- TODO confirm readers
+    ),
+    "charge" => Dict( # non-categorical: no :possible_vals, names a codec
+        :A => GraphMol,
+        :compute_f => charge,
+        :categorical => false,
+        :encodable_elements => mg_elements,
+        :codec => DirectCodec, # the type itself, not an instance -- TODO confirm readers
+    ),
 )
 
 end