diff --git a/nngeometry/generator/jacobian/jacobian.py b/nngeometry/generator/jacobian/jacobian.py
index 93b611a..7bec40a 100644
--- a/nngeometry/generator/jacobian/jacobian.py
+++ b/nngeometry/generator/jacobian/jacobian.py
@@ -312,14 +312,21 @@ def get_jacobian(self, examples):
         self.start = 0
         for d in loader:
             inputs = d[0]
-            inputs.requires_grad = True
+            differentiate_wrt = []
+            if inputs.dtype in [
+                torch.float16,
+                torch.float32,
+                torch.float64,
+            ]:
+                inputs.requires_grad = True
+                differentiate_wrt.append(inputs)
             bs = inputs.size(0)
             output = self.function(*d).view(bs, self.n_output).sum(dim=0)
             for self.i_output in range(self.n_output):
                 retain_graph = self.i_output < self.n_output - 1
                 torch.autograd.grad(
                     output[self.i_output],
-                    [inputs],
+                    differentiate_wrt,
                     retain_graph=retain_graph,
                     only_inputs=True,
                 )
diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py
index 9fcd363..9d24dec 100644
--- a/nngeometry/layercollection.py
+++ b/nngeometry/layercollection.py
@@ -25,7 +25,8 @@ class LayerCollection:
         "Affine1d",
         "ConvTranspose2d",
         "Conv1d",
-        "LayerNorm"
+        "LayerNorm",
+        "Embedding",
     ]
 
     def __init__(self, layers=None):
@@ -151,6 +152,10 @@ def _module_to_layer(mod):
             return LayerNormLayer(
                 normalized_shape=mod.normalized_shape, bias=(mod.bias is not None)
             )
+        elif mod_class == "Embedding":
+            return EmbeddingLayer(
+                embedding_dim=mod.embedding_dim, num_embeddings=mod.num_embeddings
+            )
 
     def numel(self):
         """
@@ -292,6 +297,22 @@ def __eq__(self, other):
         )
 
 
+class EmbeddingLayer(AbstractLayer):
+    def __init__(self, num_embeddings, embedding_dim):
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.weight = Parameter(num_embeddings, embedding_dim)
+
+    def numel(self):
+        return self.weight.numel()
+
+    def __eq__(self, other):
+        return (
+            self.num_embeddings == other.num_embeddings
+            and self.embedding_dim == other.embedding_dim
+        )
+
+
 class BatchNorm1dLayer(AbstractLayer):
     def __init__(self, num_features):
         self.num_features = num_features
diff --git a/tests/tasks.py b/tests/tasks.py
index 2943806..43527e7 100644
--- a/tests/tasks.py
+++ b/tests/tasks.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as tF
 from torch.nn.modules.conv import ConvTranspose2d
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import DataLoader, Subset, TensorDataset
 from torchvision import datasets, transforms
 
 from nngeometry.layercollection import LayerCollection
@@ -184,6 +184,32 @@ def output_fn(input, target):
     return (train_loader, layer_collection, net.parameters(), net, output_fn, 2)
 
 
+class EmbeddingNet(nn.Module):
+    def __init__(self):
+        super(EmbeddingNet, self).__init__()
+        self.embedding_layer = nn.Embedding(10, 3)
+
+    def forward(self, x):
+        output = self.embedding_layer(x)
+        return output.sum(dim=1)
+
+
+def get_embedding_task():
+    train_set = TensorDataset(
+        torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]), torch.LongTensor([2, 0])
+    )
+    train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=False)
+    net = EmbeddingNet()
+    to_device_model(net)
+    net.eval()
+
+    def output_fn(input, target):
+        return net(input)
+
+    layer_collection = LayerCollection.from_model(net)
+    return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
+
+
 class LinearConvNet(nn.Module):
     def __init__(self):
         super(LinearConvNet, self).__init__()
diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py
index 8eb0f83..c9f0e1d 100644
--- a/tests/test_jacobian.py
+++ b/tests/test_jacobian.py
@@ -7,6 +7,7 @@
     get_conv_gn_task,
     get_conv_skip_task,
     get_conv_task,
+    get_embedding_task,
     get_fullyconnect_affine_task,
     get_fullyconnect_cosine_task,
     get_fullyconnect_onlylast_task,
@@ -35,6 +36,7 @@
 from nngeometry.object.vector import PVector, random_fvector, random_pvector
 
 linear_tasks = [
+    get_embedding_task,
    get_linear_fc_task,
    get_linear_conv_task,
    get_batchnorm_fc_linear_task,
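
Usage sketch (not part of the patch): the snippet below shows the scenario this change enables, computing a parameter-space metric for a model whose inputs are integer token ids fed to an nn.Embedding. It uses nngeometry's FIM helper and PMatDense representation; the exact keyword arguments (variant="regression", n_output) are assumptions about the library's public API, not something this diff introduces.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from nngeometry.metrics import FIM
    from nngeometry.object import PMatDense


    class ToyEmbeddingNet(nn.Module):  # hypothetical toy model, mirrors the new test task
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(10, 3)

        def forward(self, x):
            # (bs, seq_len) int64 token ids -> (bs, 3) summed token embeddings
            return self.embedding(x).sum(dim=1)


    loader = DataLoader(
        TensorDataset(
            torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),  # integer token ids
            torch.LongTensor([2, 0]),  # dummy targets
        ),
        batch_size=2,
    )

    # Before this patch the Jacobian generator unconditionally set
    # inputs.requires_grad = True, which autograd rejects for integer
    # (LongTensor) inputs; it now differentiates only w.r.t.
    # floating-point inputs, so this call goes through.
    fim = FIM(
        model=ToyEmbeddingNet(),
        loader=loader,
        representation=PMatDense,
        n_output=3,
        variant="regression",
    )
    print(fim.trace())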