Hi, I'm trying to use `jacrev` to compute Jacobians of a graph convolutional network (GCN), but it seems I've called the function incorrectly.
```python
import torch
import torch.nn.functional as F
import functorch
import torch_geometric
from torch_geometric.data import Data


class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        torch.manual_seed(12345)
        self.conv1 = torch_geometric.nn.GCNConv(input_dim, hidden_dim, aggr='add')
        self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, output_dim, aggr='add')

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x


adj_matrix = torch.ones(3, 3)
edge_index = adj_matrix.nonzero().t().contiguous()  # shape [2, 9]: fully connected 3-node graph

gcn = GCN(input_dim=5, hidden_dim=64, output_dim=5)
N = (128, 3, 5)
x = torch.randn(N, requires_grad=True)  # batch_size: 128, node_num: 3, node_feature: 5
graph = Data(x=x, edge_index=edge_index)
gcn_out = gcn(graph.x, graph.edge_index)
```
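One thing worth checking in the model above: `forward` calls `F.dropout(..., training=self.training)`, and a freshly constructed module is in training mode. `vmap` (in functorch and in its successor `torch.func`) defaults to `randomness='error'`, which rejects random operations such as training-mode dropout. Two ways around it are calling `gcn.eval()` before computing Jacobians, or passing `randomness='different'` (an independent mask per sample) or `'same'` to `vmap`. The minimal sketch below shows both options on a bare dropout function rather than on the GCN itself; `train_dropout` and `eval_dropout` are illustrative names only.

```python
import torch
import torch.nn.functional as F

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap  # standalone functorch (older PyTorch)


def train_dropout(x):
    # Training-mode dropout draws random numbers, so vmap's default
    # randomness="error" would refuse to run it.
    return F.dropout(x, p=0.5, training=True)


def eval_dropout(x):
    # Eval-mode dropout is the identity, so no randomness flag is needed.
    return F.dropout(x, p=0.5, training=False)


x = torch.ones(8, 100)

# Option 1: keep dropout and give each sample in the batch its own mask.
y_train = vmap(train_dropout, randomness="different")(x)

# Option 2: disable dropout (what gcn.eval() does for the module above).
y_eval = vmap(eval_dropout)(x)

print(y_train.shape, y_eval.shape)
```

With `p=0.5`, kept entries are scaled by `1 / (1 - p) = 2`, so every entry of `y_train` is either 0 or 2.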
Then, following the tutorial, I try to compute the Jacobians with respect to the input data `x`:

```python
jacobian = functorch.vmap(functorch.jacrev(gcn))(graph.x, graph.edge_index)
```
and get the following error message:
```
ValueError: vmap: Expected all tensors to have the same size in the mapped dimension, got sizes [128, 2] for the mapped dimension
```
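The sizes `[128, 2]` in the message come from `vmap`'s default `in_dims=0`, which maps over dimension 0 of every positional argument: `graph.x` has size 128 there, while `edge_index` (shape `[2, 9]`) has size 2. Passing `in_dims=(0, None)` maps over the batch of node features only and hands the same `edge_index` to every call. So that it runs without torch_geometric, the sketch below replaces the two `GCNConv` layers with a hypothetical dense stand-in (`normalized_adjacency`, `W1` and `W2` are illustrative names, not library API) and omits dropout, as if the model were in eval mode; only the `vmap`/`jacrev` call pattern is meant to carry over.

```python
import torch

try:
    from torch.func import jacrev, vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import jacrev, vmap  # standalone functorch (older PyTorch)

torch.manual_seed(12345)

# Hypothetical dense stand-in for the two GCNConv layers (not torch_geometric's API).
IN_DIM, HIDDEN_DIM, OUT_DIM = 5, 64, 5
W1 = torch.randn(IN_DIM, HIDDEN_DIM)
W2 = torch.randn(HIDDEN_DIM, OUT_DIM)


def normalized_adjacency(edge_index, num_nodes):
    """Build D^-1/2 A D^-1/2 from a [2, num_edges] edge list (self-loops assumed present)."""
    src, dst = edge_index[0], edge_index[1]
    adj = torch.zeros(num_nodes, num_nodes).index_put(
        (dst, src), torch.ones(edge_index.size(1))
    )
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]


def gcn(x, edge_index):
    """Per-graph forward pass: x is [num_nodes, in_dim]; dropout omitted (eval mode)."""
    a_hat = normalized_adjacency(edge_index, x.shape[0])
    h = torch.relu(a_hat @ x @ W1)
    return a_hat @ h @ W2


adj_matrix = torch.ones(3, 3)
edge_index = adj_matrix.nonzero().t().contiguous()  # shape [2, 9]
x = torch.randn(128, 3, 5)  # batch_size 128, 3 nodes, 5 features per node

# in_dims=(0, None): map over dim 0 of x only; edge_index is shared by every sample.
# jacrev differentiates w.r.t. argument 0 (x) by default.
jacobian = vmap(jacrev(gcn), in_dims=(0, None))(x, edge_index)
print(jacobian.shape)  # torch.Size([128, 3, 5, 3, 5])
```

Each per-sample Jacobian has shape `output.shape + input.shape = [3, 5, 3, 5]`, and `vmap` stacks 128 of them. For the model in the issue, the analogous call would be `functorch.vmap(functorch.jacrev(gcn), in_dims=(0, None))(graph.x, graph.edge_index)`; this sketch does not verify that every operation inside `GCNConv` has a `vmap` batching rule.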