modify modules for multialpha #1

Open · wants to merge 11 commits into master
Conversation

akichinguyen (Owner)

No description provided.

output = {}
tr_keys = []
for i in self.transform.keys():
    tr_keys.append(int(eval(i)[0]))
@sergeikotelnikov (Collaborator) · Jul 14, 2021

I guess we don't need int() here.
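
As a minimal illustration (assuming the keys of self.transform are strings of the form "({d_out},{dv_in},{dv_out})", as discussed further down in this thread), eval() already yields Python ints for such strings, so the int() cast is a no-op:

key = "(1, 0, 2)"          # hypothetical key of the assumed form
first = eval(key)[0]       # eval() gives the tuple (1, 0, 2); its elements are already ints
assert isinstance(first, int)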

@sergeikotelnikov (Collaborator) commented Jul 15, 2021

Unfortunately, I cannot comment on non-PR lines within this PR (isaacs/github#284).

class GNormBias(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity with only learned biases."""

    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
                 num_layers: int = 0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

        # Regularization for computing phase: gradients explode otherwise
        self.eps = 1e-12

        # Norm mappings: 1 per feature type
        self.bias = nn.ParameterDict()
        for m, d in self.fiber.structure:
            self.bias[str(d)] = nn.Parameter(torch.randn(m).view(1, m))

    def __repr__(self):
        return f"GNormTFN()"

    def forward(self, features, **kwargs):
        output = {}
        for k, v in features.items():
            # Compute the norms and normalized features
            # v shape: [..., m, 2*k+1]
            norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps).expand_as(v)
            phase = v / norm

            # Transform on norms
            # transformed = self.transform[str(k)](norm[..., 0]).unsqueeze(-1)
            transformed = self.nonlin(norm[..., 0] + self.bias[str(k)])

            # Nonlinearity on norm
            output[k] = (transformed.unsqueeze(-1) * phase).view(*v.shape)
        return output

I think we can simplify it: change torch.randn(m).view(1, m) to torch.randn(1, m, 1), and get rid of .expand_as(v), [..., 0], .unsqueeze(-1), and .view(*v.shape). What do you think?
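
For illustration, a minimal sketch of the simplified forward under those changes (assuming v has shape [N, m, 2*k+1] and the bias is stored as nn.Parameter(torch.randn(1, m, 1))):

def forward(self, features, **kwargs):
    output = {}
    for k, v in features.items():
        norm = v.norm(2, -1, keepdim=True).clamp_min(self.eps)  # [N, m, 1]
        phase = v / norm
        # bias of shape [1, m, 1] broadcasts against the norm
        transformed = self.nonlin(norm + self.bias[str(k)])     # [N, m, 1]
        output[k] = transformed * phase                         # broadcasts back to [N, m, 2*k+1]
    return output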

Comment on lines 663 to 665
msg = msg.view(msg.shape[0], -1, 2*d_out+1)

return {f'out{d_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
@sergeikotelnikov (Collaborator) · Jul 15, 2021

We don't need to do the same tensor reshaping one more time.
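
For illustration, a sketch of the deduplicated version, where the view is applied only once:

msg = msg.view(msg.shape[0], -1, 2*d_out + 1)
return {f'out{d_out}': msg}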

msg = msg + torch.matmul(edge, src) #sum over all d_in => prob need to keep this separate, not sum up
msg = msg.view(msg.shape[0], -1, 2*d_out+1)

return {f'out{d_out},{dv_in},{dv_out}': msg.view(msg.shape[0], -1, 2*d_out+1)}
@sergeikotelnikov (Collaborator)

We don't need to do the same tensor reshaping one more time.

@sergeikotelnikov (Collaborator) commented Jul 15, 2021

class GMABSE3(nn.Module):
    """An SE(3)-equivariant multi-headed self-attention module for DGL graphs."""

    def __init__(self, f_value: Fiber, f_key: Fiber, n_heads: int):
        """SE(3)-equivariant MAB (multi-headed attention block) layer.

        Args:
            f_value: Fiber() object for value-embeddings
            f_key: Fiber() object for key-embeddings
            n_heads: number of heads
        """
        super().__init__()
        self.f_value = f_value
        self.f_key = f_key
        self.n_heads = n_heads
        self.new_dgl = version.parse(dgl.__version__) > version.parse('0.4.4')

    def __repr__(self):
        return f'GMABSE3(n_heads={self.n_heads}, structure={self.f_value})'

    def udf_u_mul_e(self, d_out):
        """Compute the weighted sum for a single output feature type.

        This function is set up as a User Defined Function in DGL.

        Args:
            d_out: output feature type
        Returns:
            edge -> node function handle
        """
        def fnc(edges):
            # Neighbor -> center messages
            attn = edges.data['a']
            value = edges.data[f'v{d_out}']

            # Apply attention weights
            msg = attn.unsqueeze(-1).unsqueeze(-1) * value

            return {'m': msg}
        return fnc

    @profile
    def forward(self, v, k: Dict = None, q: Dict = None, G=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            G: minibatch of (homo)graphs
            v: dict of value edge-features
            k: dict of key edge-features
            q: dict of query node-features
        Returns:
            tensor with new features [B, n_points, n_features_out]
        """
        with G.local_scope():
            # Add node features to local graph scope
            ## We use the stacked tensor representation for attention
            for m, d in self.f_value.structure:
                G.edata[f'v{d}'] = v[f'{d}'].view(-1, self.n_heads, m // self.n_heads, 2*d+1)  # keep vector shape for different type
            G.edata['k'] = fiber2head(k, self.n_heads, self.f_key, squeeze=True)  # [edges, heads, channels](?)  # concat all types into 1 vector
            G.ndata['q'] = fiber2head(q, self.n_heads, self.f_key, squeeze=True)  # [nodes, heads, channels](?)

            # Compute attention weights
            ## Inner product between (key) neighborhood and (query) center
            G.apply_edges(fn.e_dot_v('k', 'q', 'e'))

            ## Apply softmax
            e = G.edata.pop('e')
            if self.new_dgl:
                # in dgl 5.3, e has an extra dimension compared to dgl 4.3
                # in the following, we get rid of this by reshaping
                n_edges = G.edata['k'].shape[0]
                e = e.view([n_edges, self.n_heads])
            e = e / np.sqrt(self.f_key.n_features)
            G.edata['a'] = edge_softmax(G, e)

            # Perform attention-weighted message-passing
            for d in self.f_value.degrees:
                G.update_all(self.udf_u_mul_e(d), fn.sum('m', f'out{d}'))

            output = {}
            for m, d in self.f_value.structure:
                output[f'{d}'] = G.ndata[f'out{d}'].view(-1, m, 2*d+1)

            return output

I think the dot products should be divided by np.sqrt(self.f_key.n_features / self.n_heads).
The same thing applies to GMABSE3_qkv.
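
For illustration, a sketch of the proposed change in the forward above (assuming the keys are split evenly across the heads, so each head works with f_key.n_features // n_heads channels):

# scale by the per-head key dimension rather than the total key dimension
e = e / np.sqrt(self.f_key.n_features / self.n_heads)
G.edata['a'] = edge_softmax(G, e)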

@akichinguyen (Owner, Author) commented Jul 16, 2021 via email

What do you want to divide it by?

@sergeikotelnikov (Collaborator) commented Jul 16, 2021

np.sqrt(self.f_key.n_features / self.n_heads)

@sergeikotelnikov (Collaborator) commented Feb 13, 2022

cloned_d = torch.clone(G.edata['d'])
if G.edata['d'].requires_grad:
    cloned_d.requires_grad_()
    log_gradient_norm(cloned_d, 'Basis computation flow')

cloned_d = torch.clone(G.edata['d'])
if G.edata['d'].requires_grad:
    cloned_d.requires_grad_()
    log_gradient_norm(cloned_d, 'Neural networks flow')

I think cloned_d.requires_grad_() is redundant.
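
A minimal check of that point: torch.clone preserves requires_grad of its input, so the extra requires_grad_() call has no effect here.

import torch

d = torch.randn(4, 3, requires_grad=True)
cloned_d = torch.clone(d)
assert cloned_d.requires_grad  # already True without calling requires_grad_()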

@sergeikotelnikov (Collaborator)

class BN(nn.Module):
    """SE(3)-equivariant batch/layer normalization"""

    def __init__(self, m):
        """SE(3)-equivariant batch/layer normalization

        Args:
            m: int for number of output channels
        """
        super().__init__()
        self.bn = nn.LayerNorm(m)

    def forward(self, x):
        return self.bn(x)

I don't understand why they (partially) call it batch normalization and why they need this function. In essence, it is layer normalization.
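
A small example of that point: nn.LayerNorm(m) normalizes each sample over its last (channel) dimension, independently of the rest of the batch, so this wrapper behaves as layer normalization rather than batch normalization.

import torch
import torch.nn as nn

x = torch.randn(8, 16)        # [batch, channels]
ln = nn.LayerNorm(16)
out = ln(x)
print(out.mean(dim=-1))       # ~0 per sample, regardless of the other samples in the batch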

def forward(self, features, **kwargs):
    output = {}
    tr_keys = []
    for i in self.transform.keys():
@sergeikotelnikov (Collaborator) · Feb 17, 2022

The number of self.transform.keys() equals the number of ({d_out},{dv_in},{dv_out}) combinations, not the number of d_outs. Should we maybe use self.f_out.degrees instead of tr_keys, or use tr_keys = {}?

@akichinguyen (Owner, Author) · Feb 28, 2022

I think it's fine: tr_keys only contains d_out values via int(eval(i)[0]). OK, I see degrees now. Let me change it.

@akichinguyen (Owner, Author)

I tested. f_out.degrees and tr_keys contain the same elements.
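
For reference, a sketch of that check (hypothetical, assuming the key format "({d_out},{dv_in},{dv_out})"): the degrees recovered from the transform keys should match f_out.degrees as a set.

tr_keys = [int(eval(i)[0]) for i in self.transform.keys()]
assert set(tr_keys) == set(self.f_out.degrees)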

self.transform = nn.ParameterDict()
for m_out, d_out in self.f_out.structure:
    for mv_in, dv_in in self.fv_in.structure:
        for mv_out, dv_out in self.fv_out.structure:
@sergeikotelnikov (Collaborator)

Should we use self.fv_in.degrees and self.fv_out.degrees?

@akichinguyen (Owner, Author)

We can change it, but is it necessary? They are equivalent to this, except that we don't need mv_in and mv_out, right?

@sergeikotelnikov (Collaborator)

class GNormBias(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity with only learned biases."""

    def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
                 num_layers: int = 0):
        """Initializer.

        Args:
            fiber: Fiber() of feature multiplicities and types
            nonlin: nonlinearity to use everywhere
            num_layers: non-negative number of linear layers in fnc
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.num_layers = num_layers

We don't use num_layers.

@sergeikotelnikov (Collaborator)

output[k] = (transformed.unsqueeze(-1) * phase).view(*v.shape)

.view(*v.shape) is not necessary.

@sergeikotelnikov (Collaborator)

cur_inpt = m_in * m_in
net = []
for i in range(1, self.num_layers):
    net.append(nn.LayerNorm(int(cur_inpt)))

We don't need int() here.
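
A minimal illustration (assuming m_in is an int): the product is already an int, so nn.LayerNorm accepts it without the cast.

import torch.nn as nn

m_in = 4
cur_inpt = m_in * m_in          # already an int
layer = nn.LayerNorm(cur_inpt)  # no int() needed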

@sergeikotelnikov (Collaborator)

sign = scalars.sign()
scalars = scalars.abs_().clamp_min(self.eps)
scalars *= sign

Do we need to clamp here?

@akichinguyen (Owner, Author)

output[k] = (transformed.unsqueeze(-1) * phase).view(*v.shape)

.view(*v.shape) is not necessary.

Agree, they have the same shape.

@akichinguyen (Owner, Author)

sign = scalars.sign()
scalars = scalars.abs_().clamp_min(self.eps)
scalars *= sign

Do we need to clamp here?

No, I don't think we need it.
