diff --git a/deepchem/models/torch_models/__init__.py b/deepchem/models/torch_models/__init__.py
index 1475f74d37..3dfaf68146 100644
--- a/deepchem/models/torch_models/__init__.py
+++ b/deepchem/models/torch_models/__init__.py
@@ -34,8 +34,8 @@ from deepchem.models.torch_models.unet import UNet, UNetModel
 from deepchem.models.torch_models.graphconvmodel import _GraphConvTorchModel, GraphConvModel
 from deepchem.models.torch_models.smiles2vec import Smiles2Vec, Smiles2VecModel
-from deepchem.models.torch_models.robust_multitask import RobustMultitask, RobustMultitaskClassifier
 from deepchem.models.torch_models.inceptionv3 import InceptionV3Model, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux, BasicConv2d
+from deepchem.models.torch_models.robust_multitask import RobustMultitask, RobustMultitaskClassifier, RobustMultitaskRegressor
 try:
     from deepchem.models.torch_models.dmpnn import DMPNN, DMPNNModel
     from deepchem.models.torch_models.gnn import GNN, GNNHead, GNNModular
diff --git a/deepchem/models/torch_models/robust_multitask.py b/deepchem/models/torch_models/robust_multitask.py
index eee467fad3..07c422e4b7 100644
--- a/deepchem/models/torch_models/robust_multitask.py
+++ b/deepchem/models/torch_models/robust_multitask.py
@@ -4,9 +4,9 @@ import logging
 from typing import List, Tuple, Callable, Literal, Union
 from typing import Sequence as SequenceCollection
+from deepchem.utils.typing import OneOrMany, ActivationFn
 from deepchem.models.torch_models.torch_model import TorchModel
 from deepchem.models import losses
-from deepchem.utils.typing import OneOrMany, ActivationFn
 from deepchem.metrics import to_one_hot
 import datetime
@@ -209,6 +209,7 @@ def forward(
             task_outputs.append(task_output)
         output = torch.stack(task_outputs, dim=1)
+
         if self.mode == 'classification':
             if self.n_tasks == 1:
                 logits = output.view(-1, self.n_classes)
@@ -252,7 +253,7 @@ def _get_activation_class(self, activation_name: ActivationFn) -> Callable:
             return activation_name
         else:
             raise ValueError(
-                f"Invalid activation function: {activation_name}. Only activations of type nn.Module"
+                f"Invalid activation function: {activation_name}. Only activations of type torch.nn.Module are supported; torch.nn.functional activations are not supported yet."
             )
@@ -410,3 +411,68 @@ def default_generator(self,
                 y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape(
                     -1, self.n_tasks, self.n_classes)
             yield ([X_b], [y_b], [w_b])
+
+
+class RobustMultitaskRegressor(TorchModel):
+    """Implements a neural network for robust multitasking.
+
+    The key idea of this model is to have bypass layers that feed
+    directly from features to task output. This might provide some
+    flexibility to route around challenges in multitasking with
+    destructive interference.
+
+    References
+    ----------
+    This technique was introduced in [1]_
+
+    .. [1] Ramsundar, Bharath, et al. "Is multitask deep learning practical for pharma?" Journal of Chemical Information and Modeling 57.8 (2017): 2068-2076.
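+
+    Examples
+    --------
+    A minimal usage sketch (illustrative only; the data here is random and
+    the shapes and hyperparameters are arbitrary):
+
+    >>> import numpy as np
+    >>> import deepchem as dc
+    >>> n_samples, n_features, n_tasks = 10, 100, 3
+    >>> X = np.random.rand(n_samples, n_features)
+    >>> y = np.random.rand(n_samples, n_tasks)
+    >>> dataset = dc.data.NumpyDataset(X, y)
+    >>> model = RobustMultitaskRegressor(n_tasks=n_tasks,
+    ...                                  n_features=n_features,
+    ...                                  layer_sizes=[128])
+    >>> avg_loss = model.fit(dataset, nb_epoch=1)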
+
+    """
+
+    def __init__(self,
+                 n_tasks: int,
+                 n_features: int,
+                 layer_sizes: List[int] = [1000],
+                 weight_init_stddevs: OneOrMany[float] = 0.02,
+                 bias_init_consts: OneOrMany[float] = 1.0,
+                 weight_decay_penalty: float = 0.0,
+                 weight_decay_penalty_type: str = "l2",
+                 dropouts: OneOrMany[float] = 0.5,
+                 activation_fns: OneOrMany[ActivationFn] = nn.ReLU(),
+                 bypass_layer_sizes: List[int] = [100],
+                 bypass_weight_init_stddevs: OneOrMany[float] = [0.02],
+                 bypass_bias_init_consts: OneOrMany[float] = [1.0],
+                 bypass_dropouts: OneOrMany[float] = [0.5],
+                 **kwargs):
+
+        loss = losses.L2Loss()
+        output_types = ['prediction']
+        n_classes = 1  # Regression emits a single value per task.
+
+        model = RobustMultitask(
+            n_tasks=n_tasks,
+            n_features=n_features,
+            layer_sizes=layer_sizes,
+            mode='regression',
+            weight_init_stddevs=weight_init_stddevs,
+            bias_init_consts=bias_init_consts,
+            weight_decay_penalty=weight_decay_penalty,
+            weight_decay_penalty_type=weight_decay_penalty_type,
+            activation_fns=activation_fns,
+            dropouts=dropouts,
+            n_classes=n_classes,
+            bypass_layer_sizes=bypass_layer_sizes,
+            bypass_weight_init_stddevs=bypass_weight_init_stddevs,
+            bypass_bias_init_consts=bypass_bias_init_consts,
+            bypass_dropouts=bypass_dropouts)
+
+        self.shared_layers = model.shared_layers
+        self.bypass_layers = model.bypass_layers
+        self.output_layers = model.output_layers
+
+        super(RobustMultitaskRegressor,
+              self).__init__(model,
+                             loss,
+                             output_types=output_types,
+                             regularization_loss=model.regularization_loss,
+                             **kwargs)
diff --git a/deepchem/models/torch_models/tests/test_robust_multitask.py b/deepchem/models/torch_models/tests/test_robust_multitask.py
index 6604102bf9..c96c4ee497 100644
--- a/deepchem/models/torch_models/tests/test_robust_multitask.py
+++ b/deepchem/models/torch_models/tests/test_robust_multitask.py
@@ -6,7 +6,7 @@
     import torch
     import torch.nn as nn
     import deepchem as dc
-    from deepchem.models.torch_models import RobustMultitask, RobustMultitaskClassifier
+    from deepchem.models.torch_models import RobustMultitask, RobustMultitaskClassifier, RobustMultitaskRegressor
     has_torch = True
 except ModuleNotFoundError:
     has_torch = False
@@ -218,3 +218,131 @@ def test_robust_multitask_classifier_reload():
     reloaded_preds = reloaded_model.predict_on_batch(X_new)
 
     assert np.all(orig_preds == reloaded_preds), "Predictions are not the same"
+
+
+@pytest.mark.torch
+def test_robust_multitask_regressor_construction():
+    """Test that RobustMultitaskRegressor can be constructed without crashing."""
+
+    model = RobustMultitaskRegressor(
+        n_tasks=1,
+        n_features=100,
+        layer_sizes=[128, 256],
+    )
+
+    assert model is not None
+
+
+@pytest.mark.torch
+def test_robust_multitask_regression_forward():
+    """Test that the forward pass of RobustMultitaskRegressor runs without
+    crashing and that its output matches the reference TensorFlow
+    implementation.
+    """
+
+    # Reuse the dimensions of the reference TensorFlow model defined above.
+    n_tasks = n_tasks_tf
+    n_features = n_features_tf
+    layer_sizes = layer_sizes_tf
+
+    torch_model = RobustMultitaskRegressor(
+        n_tasks=n_tasks,
+        n_features=n_features,
+        layer_sizes=layer_sizes,
+    )
+
+    weights = np.load(
+        os.path.join(os.path.dirname(__file__), "assets",
+                     "tensorflow_robust_multitask_regressor_weights.npz"))
+
+    move_weights(torch_model, weights)
+
+    input_x = weights["input"]
+    output = weights["output"]
+
+    # TorchModel's predict() works on Dataset objects, so convert the numpy
+    # arrays to a NumpyDataset before running inference.
+    y = np.random.rand(input_x.shape[0], 1)
+    w = np.ones((input_x.shape[0], 1))
+    ids = np.arange(input_x.shape[0])
+    input_x = dc.data.NumpyDataset(input_x, y, w, ids)
+
+    torch_out = torch_model.predict(input_x)
+    assert np.allclose(output, torch_out,
+                       atol=1e-4), "Predictions are not close"
+
+
+@pytest.mark.torch
+def test_robust_multitask_regressor_overfit():
+    """Test that the model can overfit a simple regression dataset."""
+    n_samples = 10
+    n_features = 5
+    n_tasks = 3
+
+    np.random.seed(123)
+    torch.manual_seed(123)
+    X = np.random.rand(n_samples, n_features)
+    y = np.random.rand(n_samples, n_tasks).astype(np.float32)
+    dataset = dc.data.NumpyDataset(X, y)
+
+    regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error,
+                                          task_averager=np.mean,
+                                          mode="regression")
+
+    model = RobustMultitaskRegressor(
+        n_tasks,
+        n_features,
+        layer_sizes=[128, 256],
+        dropouts=0.2,
+        weight_init_stddevs=0.02,
+        bias_init_consts=0.0,
+    )
+
+    model.fit(dataset, nb_epoch=300)
+
+    scores = model.evaluate(dataset, [regression_metric])
+    assert scores[regression_metric.name] < 0.05, "Failed to overfit"
+
+
+@pytest.mark.torch
+def test_robust_multitask_regressor_reload():
+    """Test that the model can be reloaded from disk."""
+
+    n_samples = 20
+    n_features = 5
+    n_tasks = 3
+
+    # Generate a dummy dataset.
+    np.random.seed(123)
+    torch.manual_seed(123)
+    ids = np.arange(n_samples)
+    X = np.random.rand(n_samples, n_features)
+    y = np.random.rand(n_samples, n_tasks).astype(np.float32)
+    w = np.ones((n_samples, n_tasks))
+    dataset = dc.data.NumpyDataset(X, y, w, ids)
+
+    model_dir = tempfile.mkdtemp()
+
+    orig_model = RobustMultitaskRegressor(n_tasks,
+                                          n_features,
+                                          layer_sizes=[128, 256],
+                                          dropouts=0.2,
+                                          weight_init_stddevs=0.02,
+                                          bias_init_consts=0.0,
+                                          model_dir=model_dir)
+    orig_model.fit(dataset, nb_epoch=200)
+
+    reloaded_model = RobustMultitaskRegressor(n_tasks,
+                                              n_features,
+                                              layer_sizes=[128, 256],
+                                              dropouts=0.2,
+                                              weight_init_stddevs=0.02,
+                                              bias_init_consts=0.0,
+                                              model_dir=model_dir)
+
+    reloaded_model.restore()
+
+    X_new = np.random.rand(n_samples, n_features)
+    orig_preds = orig_model.predict_on_batch(X_new)
+    reloaded_preds = reloaded_model.predict_on_batch(X_new)
+
+    assert np.all(orig_preds == reloaded_preds), "Predictions are not the same"
diff --git a/docs/source/api_reference/models.rst b/docs/source/api_reference/models.rst
index a8d6abfd5b..e2da091d03 100644
--- a/docs/source/api_reference/models.rst
+++ b/docs/source/api_reference/models.rst
@@ -558,6 +558,12 @@ RobustMultitaskClassifier
 .. autoclass:: deepchem.models.torch_models.RobustMultitaskClassifier
    :members:
 
+RobustMultitaskRegressor
+------------------------
+
+.. autoclass:: deepchem.models.torch_models.RobustMultitaskRegressor
+   :members:
+
 Density Functional Theory Model - XCModel
 -----------------------------------------