Skip to content

Commit

Permalink
Robust regressor (deepchem#4160)
Browse files Browse the repository at this point in the history
* Robust regressor

* Robust regressor pytorch class:

* Adds more tasks to the architecture test for robust_regressor

* Adds more tasks to architecture similarity test

* Fixes mypy error

* Added PyTorch robustmultitask regressor model to rst

* Rebased and modified for robust multitask classifier and regressor

* Fixed __init__.py

* Rebased and fixed models.rst
  • Loading branch information
spellsharp authored Nov 25, 2024
1 parent 4a29b98 commit cbd54e8
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 4 deletions.
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from deepchem.models.torch_models.unet import UNet, UNetModel
from deepchem.models.torch_models.graphconvmodel import _GraphConvTorchModel, GraphConvModel
from deepchem.models.torch_models.smiles2vec import Smiles2Vec, Smiles2VecModel
from deepchem.models.torch_models.robust_multitask import RobustMultitask, RobustMultitaskClassifier
from deepchem.models.torch_models.inceptionv3 import InceptionV3Model, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux, BasicConv2d
from deepchem.models.torch_models.robust_multitask import RobustMultitask, RobustMultitaskClassifier, RobustMultitaskRegressor
try:
from deepchem.models.torch_models.dmpnn import DMPNN, DMPNNModel
from deepchem.models.torch_models.gnn import GNN, GNNHead, GNNModular
Expand Down
70 changes: 68 additions & 2 deletions deepchem/models/torch_models/robust_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import logging
from typing import List, Tuple, Callable, Literal, Union
from typing import Sequence as SequenceCollection
from deepchem.utils.typing import OneOrMany, ActivationFn
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models import losses
from deepchem.utils.typing import OneOrMany, ActivationFn
from deepchem.metrics import to_one_hot
import datetime

Expand Down Expand Up @@ -209,6 +209,7 @@ def forward(
task_outputs.append(task_output)

output = torch.stack(task_outputs, dim=1)

if self.mode == 'classification':
if self.n_tasks == 1:
logits = output.view(-1, self.n_classes)
Expand Down Expand Up @@ -252,7 +253,7 @@ def _get_activation_class(self, activation_name: ActivationFn) -> Callable:
return activation_name
else:
raise ValueError(
f"Invalid activation function: {activation_name}. Only activations of type nn.Module"
f"Invalid activation function: {activation_name}. Only activations of type torch.nn.Module (torch.nn.functional activations are not supported yet!!)"
)


Expand Down Expand Up @@ -410,3 +411,68 @@ def default_generator(self,
y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape(
-1, self.n_tasks, self.n_classes)
yield ([X_b], [y_b], [w_b])


class RobustMultitaskRegressor(TorchModel):
"""Implements a neural network for robust multitasking.
The key idea of this model is to have bypass layers that feed
directly from features to task output. This might provide some
flexibility toroute around challenges in multitasking with
destructive interference.
References
----------
This technique was introduced in [1]_
.. [1] Ramsundar, Bharath, et al. "Is multitask deep learning practical for pharma?." Journal of chemical information and modeling 57.8 (2017): 2068-2076.
"""

def __init__(self,
n_tasks,
n_features,
layer_sizes=[1000],
weight_init_stddevs: OneOrMany[float] = 0.02,
bias_init_consts: OneOrMany[float] = 1.0,
weight_decay_penalty: float = 0.0,
weight_decay_penalty_type: str = "l2",
dropouts: OneOrMany[float] = 0.5,
activation_fns: OneOrMany[ActivationFn] = nn.ReLU(),
bypass_layer_sizes=[100],
bypass_weight_init_stddevs=[.02],
bypass_bias_init_consts=[1.0],
bypass_dropouts=[0.5],
**kwargs):

loss = losses.L2Loss()
output_types = ['prediction']
n_classes = 1

model = RobustMultitask(
n_tasks=n_tasks,
n_features=n_features,
layer_sizes=layer_sizes,
mode='regression',
weight_init_stddevs=weight_init_stddevs,
bias_init_consts=bias_init_consts,
weight_decay_penalty=weight_decay_penalty,
weight_decay_penalty_type=weight_decay_penalty_type,
activation_fns=activation_fns,
dropouts=dropouts,
n_classes=n_classes,
bypass_layer_sizes=bypass_layer_sizes,
bypass_weight_init_stddevs=bypass_weight_init_stddevs,
bypass_bias_init_consts=bypass_bias_init_consts,
bypass_dropouts=bypass_dropouts)

self.shared_layers = model.shared_layers
self.bypass_layers = model.bypass_layers
self.output_layers = model.output_layers

super(RobustMultitaskRegressor,
self).__init__(model,
loss,
output_types=output_types,
regularization_loss=model.regularization_loss,
**kwargs)
130 changes: 129 additions & 1 deletion deepchem/models/torch_models/tests/test_robust_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch.nn as nn
import deepchem as dc
from deepchem.models.torch_models import RobustMultitask, RobustMultitaskClassifier
from deepchem.models.torch_models import RobustMultitask, RobustMultitaskClassifier, RobustMultitaskRegressor
has_torch = True
except ModuleNotFoundError:
has_torch = False
Expand Down Expand Up @@ -218,3 +218,131 @@ def test_robust_multitask_classifier_reload():
reloaded_preds = reloaded_model.predict_on_batch(X_new)

assert np.all(orig_preds == reloaded_preds), "Predictions are not the same"


def test_robust_multitask_regressor_construction():
"""Test that RobustMultiTaskRegressor Model can be constructed without crash.
"""

model = RobustMultitaskRegressor(
n_tasks=1,
n_features=100,
layer_sizes=[128, 256],
)

assert model is not None


@pytest.mark.torch
def test_robust_multitask_regression_forward():
"""Test that the forward pass of RobustMultiTask Model can be executed without crash
and that the output has the correct value.
"""

n_tasks = n_tasks_tf
n_features = n_features_tf
layer_sizes = layer_sizes_tf

torch_model = RobustMultitaskRegressor(
n_tasks=n_tasks,
n_features=n_features,
layer_sizes=layer_sizes,
)

weights = np.load(
os.path.join(os.path.dirname(__file__), "assets",
"tensorflow_robust_multitask_regressor_weights.npz"))

move_weights(torch_model, weights)

input_x = weights["input"]
output = weights["output"]

# Inference using TorchModel's predict() method works with NumpyDataset only. Hence we need to convert our numpy arrays to NumpyDataset.
y = np.random.rand(input_x.shape[0], 1)
w = np.ones((input_x.shape[0], 1))
ids = np.arange(input_x.shape[0])
input_x = dc.data.NumpyDataset(input_x, y, w, ids)

torch_out = torch_model.predict(input_x)
assert np.allclose(output, torch_out,
atol=1e-4), "Predictions are not close"


@pytest.mark.torch
def test_robust_multitask_regressor_overfit():
"""Test that the model can overfit simple regression datasets."""
n_samples = 10
n_features = 5
n_tasks = 3

np.random.seed(123)
torch.manual_seed(123)
X = np.random.rand(n_samples, n_features)
y = np.random.rand(n_samples, n_tasks).astype(np.float32)
dataset = dc.data.NumpyDataset(X, y)

regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error,
task_averager=np.mean,
mode="regression")

model = RobustMultitaskRegressor(
n_tasks,
n_features,
layer_sizes=[128, 256],
dropouts=0.2,
weight_init_stddevs=0.02,
bias_init_consts=0.0,
)

model.fit(dataset, nb_epoch=300)

scores = model.evaluate(dataset, [regression_metric])
assert scores[regression_metric.name] < 0.05, "Failed to overfit"


@pytest.mark.torch
def test_robust_multitask_regressor_reload():
"""Test that the model can be reloaded from disk."""

n_samples = 20
n_features = 5
n_tasks = 3

# Generate dummy dataset
np.random.seed(123)
torch.manual_seed(123)
ids = np.arange(n_samples)
X = np.random.rand(n_samples, n_features)
y = np.random.rand(n_samples, n_tasks).astype(np.float32)
w = np.ones((n_samples, n_tasks))
dataset = dc.data.NumpyDataset(X, y, w, ids)

model_dir = tempfile.mkdtemp()

orig_model = RobustMultitaskRegressor(n_tasks,
n_features,
layer_sizes=[128, 256],
dropouts=0.2,
alpha_init_stddevs=0.02,
weight_init_stddevs=0.02,
bias_init_consts=0.0,
model_dir=model_dir)
orig_model.fit(dataset, nb_epoch=200)

reloaded_model = RobustMultitaskRegressor(n_tasks,
n_features,
layer_sizes=[128, 256],
dropouts=0.2,
alpha_init_stddevs=0.02,
weight_init_stddevs=0.02,
bias_init_consts=0.0,
model_dir=model_dir)

reloaded_model.restore()

X_new = np.random.rand(n_samples, n_features)
orig_preds = orig_model.predict_on_batch(X_new)
reloaded_preds = reloaded_model.predict_on_batch(X_new)

assert np.all(orig_preds == reloaded_preds), "Predictions are not the same"
6 changes: 6 additions & 0 deletions docs/source/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,12 @@ RobustMultitaskClassifier
.. autoclass:: deepchem.models.torch_models.RobustMultitaskClassifier
:members:

RobustMultitaskRegressor
------------------------

.. autoclass:: deepchem.models.torch_models.RobustMultitaskRegressor
:members:

Density Functional Theory Model - XCModel
-----------------------------------------

Expand Down

0 comments on commit cbd54e8

Please sign in to comment.