
Equivalent of FullyConnectedTensorProduct using Flax modules #50

Answered by ameya98
Chronum94 asked this question in Q&A

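I think you need something like this: an unweighted `e3nn.tensor_product` followed by a learnable `e3nn.flax.Linear` that projects the result onto the output irreps you want: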

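```python
import e3nn_jax as e3nn
import flax.linen as nn


class WrappedFCTensorProduct(nn.Module):

    irreps_out: e3nn.Irreps

    @nn.compact
    def __call__(self, input_1: e3nn.IrrepsArray, input_2: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        # Unweighted tensor product: all allowed paths, no learnable parameters.
        output = e3nn.tensor_product(input_1, input_2)
        # Learnable equivariant linear map onto irreps_out; this holds the weights.
        output = e3nn.flax.Linear(irreps_out=self.irreps_out)(output)
        return output
```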
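To see what the `Linear` layer has to work with, here is a plain-Python sketch (no e3nn required) that counts the output channels of the unweighted tensor product of `0e + 1e + 2e + 3e` with itself, using the standard selection rule |l1 - l2| <= l <= l1 + l2 with parity p1 * p2. The helpers `parse_irreps` and `tensor_product_channels` are hypothetical and only illustrate the rule; e3nn-jax's own `tensor_product` may order or group its output irreps differently.

```python
from collections import Counter


def parse_irreps(spec: str) -> list[tuple[int, int, int]]:
    """Parse e.g. "2x0e + 1o" into (multiplicity, l, parity) tuples.

    Hypothetical helper for illustration; parity is +1 for "e", -1 for "o".
    """
    irreps = []
    for term in spec.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")  # "2x0e" -> ("2", "x", "0e"); "1o" -> ("", "", "1o")
        l, p = int(ir[:-1]), ir[-1]
        irreps.append((int(mul) if mul else 1, l, 1 if p == "e" else -1))
    return irreps


def tensor_product_channels(spec_1: str, spec_2: str) -> Counter:
    """Count output channels per irrep (l, parity) of an unweighted tensor product."""
    channels = Counter()
    for m1, l1, p1 in parse_irreps(spec_1):
        for m2, l2, p2 in parse_irreps(spec_2):
            # Selection rule: every l in |l1 - l2| .. l1 + l2, parities multiply.
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                channels[(l, p1 * p2)] += m1 * m2
    return channels


def fmt(l: int, p: int) -> str:
    return f"{l}{'e' if p == 1 else 'o'}"


channels = tensor_product_channels("0e + 1e + 2e + 3e", "0e + 1e + 2e + 3e")
for (l, p), n in sorted(channels.items()):
    print(f"{n}x{fmt(l, p)}")
# prints 4x0e, 9x1e, 11x2e, 10x3e, 6x4e, 3x5e, 1x6e (one per line)
```

Because an equivariant `Linear` only connects matching irreps, it would map 4x0e to 2x0e, 9x1e to 5x1e, 11x2e to 8x2e and 10x3e to 11x3e, and simply drop the 4e, 5e and 6e channels that `irreps_out` does not request. Its weights then play roughly the role of the per-path weights a `FullyConnectedTensorProduct` would carry.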

You can then initialize and call the module like this:

```python
import jax
import jax.numpy as jnp

input_1 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
input_2 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
tp = WrappedFCTensorProduct(irreps_out="2x0e + 5x1e + 8x2e + 11x3e")
params = tp.init(jax.random.PRNGKey(0), input_1, input_2)
output = tp.apply(params, input_1, input_2)
```
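The array sizes here follow from the irreps: an irrep of degree l has dimension 2l + 1, times its multiplicity. e3nn-jax exposes this as `e3nn.Irreps(...).dim`; the plain-Python sketch below uses a hypothetical helper, `irreps_dim`, only to spell out the arithmetic, so you can check that `jnp.ones(16)` matches `0e + 1e + 2e + 3e` and see how large the output should be.

```python
def irreps_dim(spec: str) -> int:
    """Total dimension of an irreps string such as "2x0e + 5x1e" (hypothetical helper)."""
    total = 0
    for term in spec.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")  # split off an optional "Nx" multiplicity
        l = int(ir[:-1])                   # drop the parity letter "e"/"o"
        total += (int(mul) if mul else 1) * (2 * l + 1)
    return total


print(irreps_dim("0e + 1e + 2e + 3e"))           # 16
print(irreps_dim("2x0e + 5x1e + 8x2e + 11x3e"))  # 134
```

So with unbatched inputs of size 16, `output.array` should have shape `(134,)`, laid out according to `2x0e + 5x1e + 8x2e + 11x3e`.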

Replies: 2 comments, 5 replies

4 replies
@Chronum94

@ameya98

ameya98 Dec 11, 2023
Collaborator

@mariogeiger

@Chronum94

1 reply
@Chronum94

Answer selected by Chronum94
Category: Q&A
Labels: None yet
3 participants