It is a fork of NVIDIA's SE(3)-Transformer implementation. I made some minor modifications, including
- removal of `torch.cuda.nvtx.nvtx_range`
- addition of the `nonlinearity` argument to `NormSE3`, `SE3Transformer`, and so on (a short sketch follows this list)
- addition of some basic network implementations using SE(3)-Transformer
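For example, the `nonlinearity` argument selects the activation used inside `NormSE3`. A minimal sketch, assuming the import paths of NVIDIA's original package layout (`se3_transformer.model` and `se3_transformer.model.layers`); adjust them to your installation:

```python
import torch.nn as nn

# Assumed import paths (NVIDIA's original layout); adjust if your installation differs.
from se3_transformer.model import Fiber, SE3Transformer
from se3_transformer.model.layers import NormSE3

# NormSE3 with GELU instead of the default ReLU
norm = NormSE3(Fiber([(0, 16), (1, 8)]), nonlinearity=nn.GELU())

# SE3Transformer passes the same argument down to its internal NormSE3 layers
model = SE3Transformer(
    num_layers=2,
    fiber_in=Fiber([(0, 16), (1, 8)]),
    fiber_hidden=Fiber([(0, 16), (1, 8)]),
    fiber_out=Fiber([(0, 2), (1, 1)]),
    num_heads=2,
    channels_div=2,
    fiber_edge=Fiber({}),
    norm=True,
    use_layer_norm=True,
    nonlinearity=nn.GELU(),
    low_memory=True,
)
```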
To install this package:

```bash
pip install git+http://github.com/huhlim/SE3Transformer
```

Step by step:
- Install DGL library with CUDA support
```bash
# This is an example with cudatoolkit=11.3.
# Set a proper cudatoolkit version that is compatible with your CUDA driver and DGL library.
conda install dgl -c dglteam/label/cu113
# or
pip install dgl -f https://data.dgl.ai/wheels/cu113/repo.html
```
- Install this package
```bash
pip install git+http://github.com/huhlim/SE3Transformer
```
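After installation, a quick sanity check (a minimal sketch; it only assumes the packages above installed correctly) confirms that DGL was built with CUDA support and that this package imports:

```python
import torch
import dgl
import se3_transformer

print("torch", torch.__version__, "| dgl", dgl.__version__)
print("CUDA available:", torch.cuda.is_available())

# Moving a tiny graph to the GPU fails if DGL was installed without CUDA support.
g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
if torch.cuda.is_available():
    g = g.to("cuda")
    print("DGL graph device:", g.device)
```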
- `se3_transformer.LinearModule`: a stack of `LinearSE3` and `NormSE3` layers
`SE3Transformer/se3_transformer/snippets.py`, lines 14 to 64 at commit `b74f707`:

```python
class LinearModule(nn.Module):
    """
    Operates only within a node, so it basically applies nn.Linear to every node.
    """

    def __init__(
        self,
        fiber_in: Fiber,
        fiber_hidden: Fiber,
        fiber_out: Fiber,
        n_layer: Optional[int] = 2,
        use_norm: Optional[bool] = True,
        nonlinearity: Optional[nn.Module] = nn.ReLU(),
        **kwargs,
    ):
        """
        arguments:
        - fiber_in: Fiber, numbers of input features
        - fiber_hidden: Fiber, numbers of intermediate features
        - fiber_out: Fiber, numbers of output features
        - n_layer: int, the number of linear layers
        - use_norm: bool, if True, NormSE3 will be inserted before a LinearSE3 layer
        - nonlinearity: activation function for NormSE3
        """
        super().__init__()
        #
        linear_module = []
        #
        if n_layer >= 2:
            if use_norm:
                linear_module.append(NormSE3(Fiber(fiber_in), nonlinearity=nonlinearity))
            linear_module.append(LinearSE3(Fiber(fiber_in), Fiber(fiber_hidden)))
            #
            for _ in range(n_layer - 2):
                if use_norm:
                    linear_module.append(NormSE3(Fiber(fiber_hidden), nonlinearity=nonlinearity))
                linear_module.append(LinearSE3(Fiber(fiber_hidden), Fiber(fiber_hidden)))
            #
            if use_norm:
                linear_module.append(NormSE3(Fiber(fiber_hidden), nonlinearity=nonlinearity))
            linear_module.append(LinearSE3(Fiber(fiber_hidden), Fiber(fiber_out)))
        else:
            if use_norm:
                linear_module.append(NormSE3(Fiber(fiber_in), nonlinearity=nonlinearity))
            linear_module.append(LinearSE3(Fiber(fiber_in), Fiber(fiber_out)))
        #
        self.linear_module = nn.Sequential(*linear_module)

    def forward(self, x):
        return self.linear_module(x)
```
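A minimal usage sketch (the tensors here are random placeholders; feature shapes follow the `(n_nodes, channels, 2*degree + 1)` layout used in `example.py` below):

```python
import torch
import torch.nn as nn
from se3_transformer import Fiber, LinearModule

linear = LinearModule(
    fiber_in=Fiber([(0, 8), (1, 4)]),       # 8 scalar and 4 vector features per node
    fiber_hidden=Fiber([(0, 16), (1, 8)]),
    fiber_out=Fiber([(0, 16), (1, 8)]),
    n_layer=2,
    use_norm=True,
    nonlinearity=nn.ReLU(),
)

n_node = 10
node_feats = {
    "0": torch.randn(n_node, 8, 1),  # degree-0 (scalar) features
    "1": torch.randn(n_node, 4, 3),  # degree-1 (vector) features
}
out = linear(node_feats)
print(out["0"].shape, out["1"].shape)  # expected: (10, 16, 1) and (10, 8, 3)
```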
- `se3_transformer.InteractionModule`: a wrapper of `SE3Transformer`

`SE3Transformer/se3_transformer/snippets.py`, lines 67 to 118 at commit `b74f707`:

```python
class InteractionModule(nn.Module):
    """
    Utilization of SE3-Transformer block
    """

    def __init__(
        self,
        fiber_in: Fiber,
        fiber_hidden: Fiber,
        fiber_out: Fiber,
        fiber_edge: Optional[Fiber] = Fiber({}),
        n_layer: Optional[int] = 2,
        n_head: Optional[int] = 2,
        use_norm: Optional[bool] = True,
        use_layer_norm: Optional[bool] = True,
        nonlinearity: Optional[nn.Module] = nn.ReLU(),
        low_memory: Optional[bool] = True,
        **kwargs,
    ):
        """
        arguments:
        - fiber_in: Fiber, numbers of input features
        - fiber_hidden: Fiber, numbers of intermediate features
        - fiber_out: Fiber, numbers of output features
        - fiber_edge: Fiber, numbers of edge features
        - n_layer: int, the number of attention layers
        - n_head: int, the number of attention heads
        - use_norm: bool, if True, NormSE3 will be inserted before a LinearSE3 layer
        - use_layer_norm: bool, if True, LayerNorm will be used between MLP (radial) layers
        - nonlinearity: activation function for NormSE3
        - low_memory: bool, if True, gradient checkpointing will be activated for ConvSE3
        """
        super().__init__()
        self.graph_module = SE3Transformer(
            num_layers=n_layer,
            fiber_in=fiber_in,
            fiber_hidden=fiber_hidden,
            fiber_out=fiber_out,
            num_heads=n_head,
            channels_div=2,
            fiber_edge=fiber_edge,
            norm=use_norm,
            use_layer_norm=use_layer_norm,
            nonlinearity=nonlinearity,
            low_memory=low_memory,
        )

    def forward(self, batch: dgl.DGLGraph, node_feats: torch.Tensor, edge_feats: torch.Tensor):
        out = self.graph_module(batch, node_feats=node_feats, edge_feats=edge_feats)
        return out
```
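Unlike `LinearModule`, the forward pass also needs a `dgl.DGLGraph` whose edges carry relative positions in `edata["rel_pos"]`, as built in `create_random_example()` of the example below. A minimal sketch with no edge features (random placeholder data):

```python
import torch
import dgl
from se3_transformer import Fiber, InteractionModule

interact = InteractionModule(
    fiber_in=Fiber([(0, 16), (1, 8)]),
    fiber_hidden=Fiber([(0, 16), (1, 8)]),
    fiber_out=Fiber([(0, 2), (1, 1)]),
    fiber_edge=Fiber({}),  # no edge features
    n_layer=2,
    n_head=2,
)

# a small fully connected graph with relative positions on its edges
n_node = 4
src = torch.tensor([i for i in range(n_node) for _ in range(n_node)])
dst = torch.tensor([j for _ in range(n_node) for j in range(n_node)])
g = dgl.graph((src, dst))
pos = torch.randn(n_node, 3)
g.edata["rel_pos"] = pos[dst] - pos[src]

node_feats = {"0": torch.randn(n_node, 16, 1), "1": torch.randn(n_node, 8, 3)}
out = interact(g, node_feats=node_feats, edge_feats={})
print(out["0"].shape, out["1"].shape)  # expected: (4, 2, 1) and (4, 1, 3)
```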
- An example: `LinearModule` + `InteractionModule`
`SE3Transformer/example/example.py`, lines 1 to 84 at commit `b74f707`:

```python
#!/usr/bin/env python

import torch
import torch.nn as nn
import dgl
import sys

from se3_transformer import Fiber, LinearModule, InteractionModule
from se3_transformer.utils import degree_to_dim


class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        #
        self.linear = LinearModule(**config["linear"])
        self.interact = InteractionModule(**config["interact"])

    def forward(self, batch: dgl.DGLGraph):
        edge_feats = {}
        node_feats = {str(degree): batch.ndata[f"node_feat_{degree}"] for degree in [0, 1]}
        #
        out = self.linear(node_feats)
        out = self.interact(batch, node_feats=out, edge_feats=edge_feats)
        return out


def create_random_example(n_point, fiber_in):
    # create a fully connected graph
    edges = [[], []]
    for i in range(n_point):
        for j in range(n_point):
            edges[0].append(i)
            edges[1].append(j)
    edges = tuple([torch.as_tensor(x) for x in edges])
    g = dgl.graph(edges)
    #
    pos = torch.randn((n_point, 3))
    g.ndata["pos"] = pos[:, None, :]
    for fiber in fiber_in:
        dim = degree_to_dim(fiber.degree)
        g.ndata[f"node_feat_{fiber.degree}"] = torch.randn((n_point, fiber.channels, dim))
    #
    src, dst = g.edges()
    g.edata["rel_pos"] = pos[dst] - pos[src]
    return g


def main():
    config = {}
    #
    config["linear"] = {}
    config["linear"]["fiber_in"] = Fiber([(0, 8), (1, 4)])
    config["linear"]["fiber_hidden"] = Fiber([(0, 16), (1, 8)])
    config["linear"]["fiber_out"] = Fiber([(0, 16), (1, 8)])
    config["linear"]["n_layer"] = 2
    config["linear"]["use_norm"] = True
    config["linear"]["nonlinearity"] = nn.ReLU()
    #
    config["interact"] = {}
    config["interact"]["fiber_in"] = Fiber([(0, 16), (1, 8)])
    config["interact"]["fiber_hidden"] = Fiber([(0, 16), (1, 8)])
    config["interact"]["fiber_out"] = Fiber([(0, 2), (1, 1)])
    config["interact"]["fiber_edge"] = Fiber({})
    config["interact"]["n_layer"] = 2
    config["interact"]["n_head"] = 2
    config["interact"]["use_norm"] = True
    config["interact"]["use_layer_norm"] = True
    config["interact"]["nonlinearity"] = nn.ReLU()
    config["interact"]["low_memory"] = True
    #
    model = Model(config)
    #
    batch = create_random_example(n_point=10, fiber_in=config["linear"]["fiber_in"])
    out = model(batch)
    print(out)
    print(out["0"].size())  # = (n_point, 2, 1)
    print(out["1"].size())  # = (n_point, 1, 3)


if __name__ == "__main__":
    main()
```

- A fully connected graph is created with random coordinates
- Input features: 8 scalars and 4 vectors
- Output features: 2 scalars and 1 vector
- `LinearModule`: two `LinearSE3` layers with `NormSE3`, returns 16 scalars and 8 vectors
- `InteractionModule`: two layers of attention blocks with two heads; takes the output of the `LinearModule` as `node_feats` and no `edge_feats`
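The example can be run as a script (e.g. `python example/example.py` from a checkout of this repository). The trailing dimensions of the printed outputs follow the usual degree-to-dimension convention (a degree-d feature has 2d+1 components), which `degree_to_dim` from `se3_transformer.utils` (already used in `example.py`) encodes:

```python
from se3_transformer.utils import degree_to_dim

# scalars (degree 0) have 1 component, vectors (degree 1) have 3
print(degree_to_dim(0), degree_to_dim(1))  # 1 3
# hence out["0"]: (n_point, 2, 1) and out["1"]: (n_point, 1, 3) for fiber_out = [(0, 2), (1, 1)]
```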