fix yapf/flake8/mypy
JoseAntonioSiguenza committed Dec 8, 2023
1 parent db995bd commit 05c7734
Showing 2 changed files with 59 additions and 60 deletions.
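
The commit message names three tools. As a minimal sketch, the checks could be reproduced locally on the touched module roughly like this, assuming the tools run with default configurations (DeepChem's CI may pin different flags):

import subprocess

TARGET = "deepchem/models/torch_models/acnn.py"

# yapf rewrites formatting in place; flake8 and mypy only report.
for cmd in (
    ["yapf", "--in-place", TARGET],  # formatting (the indentation churn below)
    ["flake8", TARGET],              # style lint (unused imports, comment spacing)
    ["mypy", TARGET],                # static typing (the new return annotation)
):
    subprocess.run(cmd, check=False)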
115 changes: 57 additions & 58 deletions deepchem/models/torch_models/acnn.py
@@ -4,10 +4,9 @@
 from deepchem.models.torch_models.layers import AtomicConv
 from deepchem.data import Dataset
 from deepchem.models.losses import L2Loss
-from deepchem.metrics import to_one_hot
 
-from typing import List, Union, Callable, Optional, Sequence
-from deepchem.utils.typing import OneOrMany, ActivationFn, LossFn
+from typing import List, Callable, Optional, Sequence, Tuple, Iterable
+from deepchem.utils.typing import OneOrMany, ActivationFn
 
 
 class AtomConvModel(TorchModel):
@@ -59,32 +58,31 @@ class AtomConvModel(TorchModel):
     >>> preds = atomic_convnet.predict(train)
     """
 
-    def __init__(
-            self,
-            n_tasks: int,
-            frag1_num_atoms: int = 70,
-            frag2_num_atoms: int = 634,
-            complex_num_atoms: int = 701,
-            max_num_neighbors: int = 12,
-            batch_size: int = 24,
-            atom_types: Sequence[float] = [
-                6, 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.,
-                -1.
-            ],
-            radial: Sequence[Sequence[float]] = [[
-                1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
-                8.0, 8.5, 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0
-            ], [0.0, 4.0, 8.0], [0.4]],
-            layer_sizes=[100],
-            weight_init_stddevs: OneOrMany[float] = 0.02,
-            bias_init_consts: OneOrMany[float] = 1.0,
-            weight_decay_penalty: float = 0.0,
-            weight_decay_penalty_type: str = "l2",
-            dropouts: OneOrMany[float] = 0.5,
-            activation_fns: OneOrMany[ActivationFn] = ['relu'],
-            residual: bool = False,
-            learning_rate=0.001,
-            **kwargs) -> None:
+    def __init__(self,
+                 n_tasks: int,
+                 frag1_num_atoms: int = 70,
+                 frag2_num_atoms: int = 634,
+                 complex_num_atoms: int = 701,
+                 max_num_neighbors: int = 12,
+                 batch_size: int = 24,
+                 atom_types: Sequence[float] = [
+                     6, 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35.,
+                     53., -1.
+                 ],
+                 radial: Sequence[Sequence[float]] = [[
+                     1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0,
+                     7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0
+                 ], [0.0, 4.0, 8.0], [0.4]],
+                 layer_sizes=[100],
+                 weight_init_stddevs: OneOrMany[float] = 0.02,
+                 bias_init_consts: OneOrMany[float] = 1.0,
+                 weight_decay_penalty: float = 0.0,
+                 weight_decay_penalty_type: str = "l2",
+                 dropouts: OneOrMany[float] = 0.5,
+                 activation_fns: OneOrMany[ActivationFn] = ['relu'],
+                 residual: bool = False,
+                 learning_rate=0.001,
+                 **kwargs) -> None:
         """TorchModel wrapper for ACNN
 
         Parameters
@@ -142,20 +140,20 @@ def __init__(
         self.atom_types = atom_types
 
         self.model = AtomicConv(n_tasks=n_tasks,
-            frag1_num_atoms=frag1_num_atoms,
-            frag2_num_atoms=frag2_num_atoms,
-            complex_num_atoms=complex_num_atoms,
-            max_num_neighbors=max_num_neighbors,
-            batch_size=batch_size,
-            atom_types=atom_types,
-            radial=radial,
-            layer_sizes=layer_sizes,
-            weight_init_stddevs=weight_init_stddevs,
-            bias_init_consts=bias_init_consts,
-            dropouts=dropouts,
-            activation_fns=activation_fns,
-            residual=residual,
-            learning_rate=learning_rate)
+                                frag1_num_atoms=frag1_num_atoms,
+                                frag2_num_atoms=frag2_num_atoms,
+                                complex_num_atoms=complex_num_atoms,
+                                max_num_neighbors=max_num_neighbors,
+                                batch_size=batch_size,
+                                atom_types=atom_types,
+                                radial=radial,
+                                layer_sizes=layer_sizes,
+                                weight_init_stddevs=weight_init_stddevs,
+                                bias_init_consts=bias_init_consts,
+                                dropouts=dropouts,
+                                activation_fns=activation_fns,
+                                residual=residual,
+                                learning_rate=learning_rate)
 
         regularization_loss: Optional[Callable]
 
@@ -170,21 +168,22 @@ def __init__(
         else:
             regularization_loss = None
 
-
         loss = L2Loss()
 
-        super(AtomConvModel, self).__init__(self.model,
-                                            loss=loss,
-                                            batch_size=batch_size,
-                                            regularization_loss=regularization_loss,
-                                            **kwargs)
-
-    def default_generator(self,
-                          dataset: Dataset,
-                          epochs: int = 1,
-                          mode: str = 'fit',
-                          deterministic: bool =True,
-                          pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
+        super(AtomConvModel,
+              self).__init__(self.model,
+                             loss=loss,
+                             batch_size=batch_size,
+                             regularization_loss=regularization_loss,
+                             **kwargs)
+
+    def default_generator(
+            self,
+            dataset: Dataset,
+            epochs: int = 1,
+            mode: str = 'fit',
+            deterministic: bool = True,
+            pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
         """Convert a dataset into the tensors needed for learning.
 
         Parameters
Expand Down Expand Up @@ -213,7 +212,7 @@ def replace_atom_types(z):
Parameters
----------
z: list
Atom types learned from the model.
Atom types learned from the model.
Returns
-------
@@ -297,4 +296,4 @@ def replace_atom_types(z):
 
                 y_b = np.reshape(y_b, newshape=(batch_size, 1))
 
-                yield (inputs, [y_b], [w_b])
+                yield (inputs, [y_b], [w_b])
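
The refactor above is whitespace- and typing-only, so the public API is unchanged. A minimal usage sketch, assuming a dataset train built with DeepChem's atomic-convolution featurization (the defaults mirror the reformatted signature; fit and predict come from the TorchModel base class):

from deepchem.models.torch_models import AtomConvModel

# Defaults mirror the reformatted __init__ signature above.
model = AtomConvModel(n_tasks=1,
                      frag1_num_atoms=70,     # max atoms in the ligand fragment
                      frag2_num_atoms=634,    # max atoms in the protein fragment
                      complex_num_atoms=701,  # ligand + protein
                      batch_size=24)

# default_generator yields (inputs, [y_b], [w_b]) batches, with y_b
# reshaped to (batch_size, 1) as in the generator code above.
# model.fit(train, nb_epoch=1)
# preds = model.predict(train)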
4 changes: 2 additions & 2 deletions deepchem/models/torch_models/tests/test_acnn.py
@@ -109,8 +109,8 @@ def test_atomic_convolution_model():
 def test_atomic_convolution_model_variable():
     """A simple test that initializes and fits an AtomConvModel on variable input size."""
     from deepchem.models.torch_models import AtomConvModel
-    frag1_num_atoms = 100 # atoms for ligand
-    frag2_num_atoms = 1200 # atoms for protein
+    frag1_num_atoms = 100  # atoms for ligand
+    frag2_num_atoms = 1200  # atoms for protein
     complex_num_atoms = frag1_num_atoms + frag2_num_atoms
     batch_size = 1
     atomic_convnet = AtomConvModel(n_tasks=1,
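
To confirm the formatting changes are behavior-preserving, the touched test can be run on its own. A sketch via pytest's Python entry point, with the node id assembled from the file path and test name shown above:

import pytest

# Run only the variable-input-size test exercised by this commit.
pytest.main([
    "deepchem/models/torch_models/tests/test_acnn.py::test_atomic_convolution_model_variable",
    "-q",
])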
